mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/core/common/const.py
CHANGED
|
@@ -14,11 +14,13 @@ class Const:
|
|
|
14
14
|
REGEX_PREFIX_MAX_LENGTH = 20
|
|
15
15
|
REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
|
|
16
16
|
FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
|
|
17
|
+
STRING_BLACKLIST = r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]"
|
|
17
18
|
COMMA = ","
|
|
18
19
|
FLOAT_EPSILON = np.finfo(float).eps
|
|
19
20
|
OFF = 'OFF'
|
|
20
21
|
BACKWARD = 'backward'
|
|
21
22
|
FORWARD = 'forward'
|
|
23
|
+
JIT = 'Jit'
|
|
22
24
|
PRIMITIVE_PREFIX = 'Primitive'
|
|
23
25
|
DEFAULT_LIST = []
|
|
24
26
|
DEFAULT_PATH = './'
|
|
@@ -30,6 +32,7 @@ class Const:
|
|
|
30
32
|
FOUR_SEGMENT = 4
|
|
31
33
|
SIX_SEGMENT = 6
|
|
32
34
|
SEVEN_SEGMENT = 7
|
|
35
|
+
MAX_DEPTH = 10
|
|
33
36
|
|
|
34
37
|
# dump mode
|
|
35
38
|
ALL = "all"
|
|
@@ -78,6 +81,7 @@ class Const:
|
|
|
78
81
|
RUN_UT = "run_ut"
|
|
79
82
|
GRAD_PROBE = "grad_probe"
|
|
80
83
|
TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
|
|
84
|
+
DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
|
|
81
85
|
LEVEL_L0 = "L0"
|
|
82
86
|
LEVEL_L1 = "L1"
|
|
83
87
|
LEVEL_L2 = "L2"
|
|
@@ -100,6 +104,30 @@ class Const:
|
|
|
100
104
|
CUDA_LOWERCASE = 'cuda'
|
|
101
105
|
DISTRIBUTED = 'Distributed'
|
|
102
106
|
|
|
107
|
+
# struct json param
|
|
108
|
+
ORIGIN_DATA = "origin_data"
|
|
109
|
+
SCOPE = "scope"
|
|
110
|
+
STACK = "stack"
|
|
111
|
+
|
|
112
|
+
ATEN = "Aten"
|
|
113
|
+
MODULE_WHITE_LIST = ["torch", "numpy"]
|
|
114
|
+
|
|
115
|
+
FUNC_SKIP_LIST = ["construct", "__call__"]
|
|
116
|
+
|
|
117
|
+
FILE_SKIP_LIST = ["site-packages/mindspore", "package/mindspore", "msprobe", "site-packages/torch", "package/torch"]
|
|
118
|
+
|
|
119
|
+
STACK_FILE_INDEX = 0
|
|
120
|
+
|
|
121
|
+
STACK_FUNC_INDEX = 2
|
|
122
|
+
|
|
123
|
+
STACK_FUNC_ELE_INDEX = 1
|
|
124
|
+
|
|
125
|
+
CONSTRUCT_NAME_INDEX = -3
|
|
126
|
+
|
|
127
|
+
NAME_FIRST_POSSIBLE_INDEX = -4
|
|
128
|
+
|
|
129
|
+
NAME_SECOND_POSSIBLE_INDEX = -5
|
|
130
|
+
|
|
103
131
|
INPLACE_LIST = [
|
|
104
132
|
"broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
|
|
105
133
|
"_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all",
|
|
@@ -114,6 +142,23 @@ class Const:
|
|
|
114
142
|
"int32_to_int64": ["cross_entropy"]
|
|
115
143
|
}
|
|
116
144
|
|
|
145
|
+
FILL_CHAR_NUMS = 50
|
|
146
|
+
TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
|
|
147
|
+
WITHOUT_CALL_STACK = "The call stack retrieval failed."
|
|
148
|
+
|
|
149
|
+
STEP = "step"
|
|
150
|
+
RANK = "rank"
|
|
151
|
+
HYPHEN = "-"
|
|
152
|
+
STEP_RANK_MAXIMUM_RANGE = [int(0), int(1e6)]
|
|
153
|
+
|
|
154
|
+
# data type const
|
|
155
|
+
FLOAT16 = "Float16"
|
|
156
|
+
FLOAT32 = "Float32"
|
|
157
|
+
BFLOAT16 = "BFloat16"
|
|
158
|
+
TORCH_FLOAT16 = "torch.float16"
|
|
159
|
+
TORCH_FLOAT32 = "torch.float32"
|
|
160
|
+
TORCH_BFLOAT16 = "torch.bfloat16"
|
|
161
|
+
|
|
117
162
|
|
|
118
163
|
class CompareConst:
|
|
119
164
|
"""
|
|
@@ -159,6 +204,7 @@ class CompareConst:
|
|
|
159
204
|
INPUT_STRUCT = "input_struct"
|
|
160
205
|
OUTPUT_STRUCT = "output_struct"
|
|
161
206
|
SUMMARY = "summary"
|
|
207
|
+
MAX_EXCEL_LENGTH = 1048576
|
|
162
208
|
|
|
163
209
|
COMPARE_RESULT_HEADER = [
|
|
164
210
|
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
|
|
@@ -197,6 +243,8 @@ class CompareConst:
|
|
|
197
243
|
ERROR = 'error'
|
|
198
244
|
SKIP = 'SKIP'
|
|
199
245
|
N_A = 'N/A'
|
|
246
|
+
INF = 'inf'
|
|
247
|
+
NEG_INF = '-inf'
|
|
200
248
|
BFLOAT16_MIN = -3.3895313892515355e+38
|
|
201
249
|
BFLOAT16_MAX = 3.3895313892515355e+38
|
|
202
250
|
BFLOAT16_EPS = 3.90625e-3 # 2 ** -8
|
|
@@ -274,7 +322,8 @@ class FileCheckConst:
|
|
|
274
322
|
MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
275
323
|
MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
276
324
|
MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
277
|
-
MAX_YAML_SIZE = 1048576 #
|
|
325
|
+
MAX_YAML_SIZE = 1048576 # 1 * 1024 * 1024
|
|
326
|
+
COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
|
|
278
327
|
DIR = "dir"
|
|
279
328
|
FILE = "file"
|
|
280
329
|
DATA_DIR_AUTHORITY = 0o750
|
|
@@ -287,6 +336,7 @@ class FileCheckConst:
|
|
|
287
336
|
CSV_SUFFIX: MAX_CSV_SIZE,
|
|
288
337
|
YAML_SUFFIX: MAX_YAML_SIZE
|
|
289
338
|
}
|
|
339
|
+
CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
|
|
290
340
|
|
|
291
341
|
|
|
292
342
|
class OverflowConst:
|
|
@@ -329,11 +379,22 @@ class MsgConst:
|
|
|
329
379
|
"""
|
|
330
380
|
Class for log messages const
|
|
331
381
|
"""
|
|
332
|
-
CLEAR_SYMBOL = "\033[K"
|
|
333
382
|
MSPROBE_LOG_LEVEL = "MSPROBE_LOG_LEVEL"
|
|
334
|
-
|
|
383
|
+
LOG_LEVEL_ENUM = ["0", "1", "2", "3", "4"]
|
|
384
|
+
LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
|
|
385
|
+
class LogLevel:
|
|
386
|
+
class DEBUG:
|
|
387
|
+
value = 0
|
|
388
|
+
class INFO:
|
|
389
|
+
value = 1
|
|
390
|
+
class WARNING:
|
|
391
|
+
value = 2
|
|
392
|
+
class ERROR:
|
|
393
|
+
value = 3
|
|
335
394
|
SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
|
|
336
395
|
|
|
396
|
+
NOT_CREATED_INSTANCE = "PrecisionDebugger instance is not created."
|
|
397
|
+
|
|
337
398
|
|
|
338
399
|
class GraphMode:
|
|
339
400
|
NPY_MODE = "NPY_MODE"
|
|
@@ -13,8 +13,8 @@ class MsprobeException(CodedException):
|
|
|
13
13
|
OVERFLOW_NUMS_ERROR = 1
|
|
14
14
|
|
|
15
15
|
err_strs = {
|
|
16
|
-
INVALID_PARAM_ERROR: "[msprobe] 无效参数:
|
|
17
|
-
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数
|
|
16
|
+
INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
|
|
17
|
+
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:"
|
|
18
18
|
}
|
|
19
19
|
|
|
20
20
|
|
|
@@ -22,6 +22,7 @@ import re
|
|
|
22
22
|
import shutil
|
|
23
23
|
import yaml
|
|
24
24
|
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
25
26
|
|
|
26
27
|
from msprobe.core.common.log import logger
|
|
27
28
|
from msprobe.core.common.exceptions import FileCheckException
|
|
@@ -187,7 +188,7 @@ def check_other_user_writable(path):
|
|
|
187
188
|
|
|
188
189
|
def check_path_owner_consistent(path):
|
|
189
190
|
file_owner = os.stat(path).st_uid
|
|
190
|
-
if file_owner != os.getuid():
|
|
191
|
+
if file_owner != os.getuid() and os.getuid() != 0:
|
|
191
192
|
logger.error('The file path %s may be insecure because is does not belong to you.' % path)
|
|
192
193
|
raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
|
|
193
194
|
|
|
@@ -214,7 +215,9 @@ def check_common_file_size(file_path):
|
|
|
214
215
|
for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
|
|
215
216
|
if file_path.endswith(suffix):
|
|
216
217
|
check_file_size(file_path, max_size)
|
|
217
|
-
|
|
218
|
+
return
|
|
219
|
+
check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
|
|
220
|
+
|
|
218
221
|
|
|
219
222
|
|
|
220
223
|
def check_file_suffix(file_path, file_suffix):
|
|
@@ -322,7 +325,7 @@ def check_file_type(path):
|
|
|
322
325
|
elif os.path.isfile(path):
|
|
323
326
|
return FileCheckConst.FILE
|
|
324
327
|
else:
|
|
325
|
-
logger.error('
|
|
328
|
+
logger.error(f'{path} does not exist, please check!')
|
|
326
329
|
raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
|
|
327
330
|
|
|
328
331
|
|
|
@@ -338,10 +341,10 @@ def load_yaml(yaml_path):
|
|
|
338
341
|
return yaml_data
|
|
339
342
|
|
|
340
343
|
|
|
341
|
-
def load_npy(filepath
|
|
344
|
+
def load_npy(filepath):
|
|
342
345
|
check_file_or_directory_path(filepath)
|
|
343
346
|
try:
|
|
344
|
-
npy = np.load(filepath
|
|
347
|
+
npy = np.load(filepath)
|
|
345
348
|
except Exception as e:
|
|
346
349
|
logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
|
|
347
350
|
raise RuntimeError(f"Load numpy file {filepath} failed.") from e
|
|
@@ -374,6 +377,20 @@ def save_json(json_path, data, indent=None):
|
|
|
374
377
|
change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
375
378
|
|
|
376
379
|
|
|
380
|
+
def save_yaml(yaml_path, data):
|
|
381
|
+
yaml_path = os.path.realpath(yaml_path)
|
|
382
|
+
check_path_before_create(yaml_path)
|
|
383
|
+
try:
|
|
384
|
+
with FileOpen(yaml_path, 'w') as f:
|
|
385
|
+
fcntl.flock(f, fcntl.LOCK_EX)
|
|
386
|
+
yaml.dump(data, f, sort_keys=False)
|
|
387
|
+
fcntl.flock(f, fcntl.LOCK_UN)
|
|
388
|
+
except Exception as e:
|
|
389
|
+
logger.error(f'Save yaml file "{os.path.basename(yaml_path)}" failed.')
|
|
390
|
+
raise RuntimeError(f"Save yaml file {yaml_path} failed.") from e
|
|
391
|
+
change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
392
|
+
|
|
393
|
+
|
|
377
394
|
def move_file(src_path, dst_path):
|
|
378
395
|
check_file_or_directory_path(src_path)
|
|
379
396
|
check_path_before_create(dst_path)
|
|
@@ -396,9 +413,9 @@ def save_npy(data, filepath):
|
|
|
396
413
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
397
414
|
|
|
398
415
|
|
|
399
|
-
def save_npy_to_txt(
|
|
416
|
+
def save_npy_to_txt(data, dst_file='', align=0):
|
|
400
417
|
if os.path.exists(dst_file):
|
|
401
|
-
|
|
418
|
+
logger.info("Dst file %s exists, will not save new one." % dst_file)
|
|
402
419
|
return
|
|
403
420
|
shape = data.shape
|
|
404
421
|
data = data.flatten()
|
|
@@ -411,7 +428,7 @@ def save_npy_to_txt(self, data, dst_file='', align=0):
|
|
|
411
428
|
try:
|
|
412
429
|
np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
|
|
413
430
|
except Exception as e:
|
|
414
|
-
|
|
431
|
+
logger.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
|
|
415
432
|
change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
416
433
|
|
|
417
434
|
|
|
@@ -431,7 +448,25 @@ def save_workbook(workbook, file_path):
|
|
|
431
448
|
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
432
449
|
|
|
433
450
|
|
|
434
|
-
def write_csv(data, filepath, mode="a+"):
|
|
451
|
+
def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
452
|
+
def csv_value_is_valid(value: str) -> bool:
|
|
453
|
+
if not isinstance(value, str):
|
|
454
|
+
return True
|
|
455
|
+
try:
|
|
456
|
+
# -1.00 or +1.00 should be consdiered as digit numbers
|
|
457
|
+
float(value)
|
|
458
|
+
except ValueError:
|
|
459
|
+
# otherwise, they will be considered as formular injections
|
|
460
|
+
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
461
|
+
return True
|
|
462
|
+
|
|
463
|
+
if malicious_check:
|
|
464
|
+
for row in data:
|
|
465
|
+
for cell in row:
|
|
466
|
+
if not csv_value_is_valid(cell):
|
|
467
|
+
raise RuntimeError(f"Malicious value [{cell}] is not allowed " \
|
|
468
|
+
f"to be written into the csv: {filepath}.")
|
|
469
|
+
|
|
435
470
|
file_path = os.path.realpath(filepath)
|
|
436
471
|
check_path_before_create(filepath)
|
|
437
472
|
try:
|
|
@@ -444,6 +479,16 @@ def write_csv(data, filepath, mode="a+"):
|
|
|
444
479
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
445
480
|
|
|
446
481
|
|
|
482
|
+
def read_csv(filepath):
|
|
483
|
+
check_file_or_directory_path(filepath)
|
|
484
|
+
try:
|
|
485
|
+
csv_data = pd.read_csv(filepath)
|
|
486
|
+
except Exception as e:
|
|
487
|
+
logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
|
|
488
|
+
raise RuntimeError(f"Read csv file {filepath} failed.") from e
|
|
489
|
+
return csv_data
|
|
490
|
+
|
|
491
|
+
|
|
447
492
|
def remove_path(path):
|
|
448
493
|
if not os.path.exists(path):
|
|
449
494
|
return
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class InplaceOpChecker:
|
|
6
|
+
OP_FUNCTIONAL = 'functional'
|
|
7
|
+
OP_TENSOR = 'tensor'
|
|
8
|
+
OP_TORCH = 'torch'
|
|
9
|
+
OP_DISTRIBUTED = 'distributed'
|
|
10
|
+
|
|
11
|
+
INPLACE_OPS_DICT = None
|
|
12
|
+
|
|
13
|
+
@classmethod
|
|
14
|
+
def load_ops(cls):
|
|
15
|
+
if cls.INPLACE_OPS_DICT is None:
|
|
16
|
+
cls.INPLACE_OPS_DICT = dict()
|
|
17
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
18
|
+
yaml_path = os.path.join(cur_path, "inplace_ops.yaml")
|
|
19
|
+
all_ops = load_yaml(yaml_path)
|
|
20
|
+
cls.INPLACE_OPS_DICT[cls.OP_FUNCTIONAL] = all_ops.get('inplace_functional_op')
|
|
21
|
+
cls.INPLACE_OPS_DICT[cls.OP_TENSOR] = all_ops.get('inplace_tensor_op')
|
|
22
|
+
cls.INPLACE_OPS_DICT[cls.OP_TORCH] = all_ops.get('inplace_torch_op')
|
|
23
|
+
cls.INPLACE_OPS_DICT[cls.OP_DISTRIBUTED] = all_ops.get('inplace_distributed_op')
|
|
24
|
+
|
|
25
|
+
@classmethod
|
|
26
|
+
def check(cls, api, category='distributed'):
|
|
27
|
+
"""
|
|
28
|
+
给定api和分类,检查其是否为inplace操作
|
|
29
|
+
"""
|
|
30
|
+
if not cls.INPLACE_OPS_DICT:
|
|
31
|
+
cls.load_ops()
|
|
32
|
+
|
|
33
|
+
if category not in cls.INPLACE_OPS_DICT.keys():
|
|
34
|
+
return False
|
|
35
|
+
return api in cls.INPLACE_OPS_DICT[category]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
InplaceOpChecker.load_ops()
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
inplace_functional_op:
|
|
2
|
+
- threshold_
|
|
3
|
+
- relu_
|
|
4
|
+
- hardtanh_
|
|
5
|
+
- elu_
|
|
6
|
+
- selu_
|
|
7
|
+
- celu_
|
|
8
|
+
- leaky_relu_
|
|
9
|
+
- rrelu_
|
|
10
|
+
|
|
11
|
+
inplace_tensor_op:
|
|
12
|
+
- __iadd__
|
|
13
|
+
- __iand__
|
|
14
|
+
- __idiv__
|
|
15
|
+
- __ifloordiv__
|
|
16
|
+
- __ilshift__
|
|
17
|
+
- __imod__
|
|
18
|
+
- __imul__
|
|
19
|
+
- __ior__
|
|
20
|
+
- __irshift__
|
|
21
|
+
- __isub__
|
|
22
|
+
- __ixor__
|
|
23
|
+
- abs_
|
|
24
|
+
- absolute_
|
|
25
|
+
- acos_
|
|
26
|
+
- acosh_
|
|
27
|
+
- add_
|
|
28
|
+
- addbmm_
|
|
29
|
+
- addcdiv_
|
|
30
|
+
- addcmul_
|
|
31
|
+
- addmm_
|
|
32
|
+
- addmv_
|
|
33
|
+
- addr_
|
|
34
|
+
- arccos_
|
|
35
|
+
- arccosh_
|
|
36
|
+
- arcsin_
|
|
37
|
+
- arcsinh_
|
|
38
|
+
- arctan_
|
|
39
|
+
- arctanh_
|
|
40
|
+
- asin_
|
|
41
|
+
- asinh_
|
|
42
|
+
- atan2_
|
|
43
|
+
- atan_
|
|
44
|
+
- atanh_
|
|
45
|
+
- baddbmm_
|
|
46
|
+
- bernoulli_
|
|
47
|
+
- bitwise_and_
|
|
48
|
+
- bitwise_not_
|
|
49
|
+
- bitwise_or_
|
|
50
|
+
- bitwise_xor_
|
|
51
|
+
- cauchy_
|
|
52
|
+
- ceil_
|
|
53
|
+
- clamp_
|
|
54
|
+
- clamp_max_
|
|
55
|
+
- clamp_min_
|
|
56
|
+
- clip_
|
|
57
|
+
- copysign_
|
|
58
|
+
- cos_
|
|
59
|
+
- cosh_
|
|
60
|
+
- cumprod_
|
|
61
|
+
- cumsum_
|
|
62
|
+
- deg2rad_
|
|
63
|
+
- digamma_
|
|
64
|
+
- div_
|
|
65
|
+
- divide_
|
|
66
|
+
- eq_
|
|
67
|
+
- erf_
|
|
68
|
+
- erfc_
|
|
69
|
+
- erfinv_
|
|
70
|
+
- exp2_
|
|
71
|
+
- exp_
|
|
72
|
+
- expm1_
|
|
73
|
+
- exponential_
|
|
74
|
+
- fill_
|
|
75
|
+
- fill_diagonal_
|
|
76
|
+
- fix_
|
|
77
|
+
- float_power_
|
|
78
|
+
- floor_
|
|
79
|
+
- floor_divide_
|
|
80
|
+
- fmod_
|
|
81
|
+
- frac_
|
|
82
|
+
- gcd_
|
|
83
|
+
- ge_
|
|
84
|
+
- geometric_
|
|
85
|
+
- greater_
|
|
86
|
+
- gt_
|
|
87
|
+
- greater_equal_
|
|
88
|
+
- heaviside_
|
|
89
|
+
- hypot_
|
|
90
|
+
- igamma_
|
|
91
|
+
- igammac_
|
|
92
|
+
- index_add_
|
|
93
|
+
- index_copy_
|
|
94
|
+
- index_fill_
|
|
95
|
+
- index_put_
|
|
96
|
+
- lcm_
|
|
97
|
+
- ldexp_
|
|
98
|
+
- le_
|
|
99
|
+
- lerp_
|
|
100
|
+
- less_
|
|
101
|
+
- less_equal_
|
|
102
|
+
- lgamma_
|
|
103
|
+
- log10_
|
|
104
|
+
- log1p_
|
|
105
|
+
- log2_
|
|
106
|
+
- log_
|
|
107
|
+
- log_normal_
|
|
108
|
+
- logical_and_
|
|
109
|
+
- logical_not_
|
|
110
|
+
- logical_or_
|
|
111
|
+
- logical_xor_
|
|
112
|
+
- logit_
|
|
113
|
+
- lt_
|
|
114
|
+
- map2_
|
|
115
|
+
- map_
|
|
116
|
+
- masked_fill_
|
|
117
|
+
- masked_scatter_
|
|
118
|
+
- mul_
|
|
119
|
+
- multiply_
|
|
120
|
+
- mvlgamma_
|
|
121
|
+
- ne_
|
|
122
|
+
- neg_
|
|
123
|
+
- negative_
|
|
124
|
+
- normal_
|
|
125
|
+
- not_equal_
|
|
126
|
+
- pow_
|
|
127
|
+
- polygamma_
|
|
128
|
+
- put_
|
|
129
|
+
- rad2deg_
|
|
130
|
+
- reciprocal_
|
|
131
|
+
- relu_
|
|
132
|
+
- remainder_
|
|
133
|
+
- renorm_
|
|
134
|
+
- resize_
|
|
135
|
+
- resize_as_
|
|
136
|
+
- round_
|
|
137
|
+
- rsqrt_
|
|
138
|
+
- scatter_
|
|
139
|
+
- scatter_add_
|
|
140
|
+
- sgn_
|
|
141
|
+
- sigmoid_
|
|
142
|
+
- sign_
|
|
143
|
+
- sin_
|
|
144
|
+
- sinc_
|
|
145
|
+
- sinh_
|
|
146
|
+
- sqrt_
|
|
147
|
+
- square_
|
|
148
|
+
- squeeze_
|
|
149
|
+
- sub_
|
|
150
|
+
- t_
|
|
151
|
+
- tan_
|
|
152
|
+
- tanh_
|
|
153
|
+
- transpose_
|
|
154
|
+
- tril_
|
|
155
|
+
- triu_
|
|
156
|
+
- true_divide_
|
|
157
|
+
- trunc_
|
|
158
|
+
- unsqueeze_
|
|
159
|
+
- xlogy_
|
|
160
|
+
|
|
161
|
+
inplace_torch_op:
|
|
162
|
+
- _add_relu_
|
|
163
|
+
- abs_
|
|
164
|
+
- acos_
|
|
165
|
+
- acosh_
|
|
166
|
+
- addmv_
|
|
167
|
+
- alpha_dropout_
|
|
168
|
+
- arccos_
|
|
169
|
+
- arccosh_
|
|
170
|
+
- arcsin_
|
|
171
|
+
- arcsinh_
|
|
172
|
+
- arctan_
|
|
173
|
+
- arctanh_
|
|
174
|
+
- asin_
|
|
175
|
+
- asinh_
|
|
176
|
+
- atan_
|
|
177
|
+
- atanh_
|
|
178
|
+
- ceil_
|
|
179
|
+
- celu_
|
|
180
|
+
- clamp_
|
|
181
|
+
- clamp_max_
|
|
182
|
+
- clamp_min_
|
|
183
|
+
- clip_
|
|
184
|
+
- cos_
|
|
185
|
+
- cosh_
|
|
186
|
+
- deg2rad_
|
|
187
|
+
- dropout_
|
|
188
|
+
- embedding_renorm_
|
|
189
|
+
- erf_
|
|
190
|
+
- erfc_
|
|
191
|
+
- exp2_
|
|
192
|
+
- exp_
|
|
193
|
+
- expm1_
|
|
194
|
+
- feature_alpha_dropout_
|
|
195
|
+
- feature_dropout_
|
|
196
|
+
- fill_
|
|
197
|
+
- fix_
|
|
198
|
+
- floor_
|
|
199
|
+
- frac_
|
|
200
|
+
- gcd_
|
|
201
|
+
- index_put_
|
|
202
|
+
- lcm_
|
|
203
|
+
- ldexp_
|
|
204
|
+
- log10_
|
|
205
|
+
- log1p_
|
|
206
|
+
- log2_
|
|
207
|
+
- log_
|
|
208
|
+
- logit_
|
|
209
|
+
- nan_to_num_
|
|
210
|
+
- neg_
|
|
211
|
+
- negative_
|
|
212
|
+
- rad2deg_
|
|
213
|
+
- reciprocal_
|
|
214
|
+
- relu_
|
|
215
|
+
- resize_as_
|
|
216
|
+
- round_
|
|
217
|
+
- rrelu_
|
|
218
|
+
- rsqrt_
|
|
219
|
+
- selu_
|
|
220
|
+
- sigmoid_
|
|
221
|
+
- sin_
|
|
222
|
+
- sinc_
|
|
223
|
+
- sinh_
|
|
224
|
+
- sqrt_
|
|
225
|
+
- square_
|
|
226
|
+
- tan_
|
|
227
|
+
- tanh_
|
|
228
|
+
- threshold_
|
|
229
|
+
- trunc_
|
|
230
|
+
- xlogy_
|
|
231
|
+
|
|
232
|
+
inplace_distributed_op:
|
|
233
|
+
- broadcast
|
|
234
|
+
- all_reduce
|
|
235
|
+
- reduce
|
|
236
|
+
- all_gather
|
|
237
|
+
- gather
|
|
238
|
+
- scatter
|
|
239
|
+
- reduce_scatter
|
|
240
|
+
- _reduce_scatter_base
|
|
241
|
+
- _all_gather_base
|
|
242
|
+
- send
|
|
243
|
+
- recv
|
|
244
|
+
- irecv
|
|
245
|
+
- isend
|
|
246
|
+
- all_to_all_single
|
|
247
|
+
- all_to_all
|
|
248
|
+
- all_gather_into_tensor
|
|
249
|
+
- reduce_scatter_tensor
|
|
250
|
+
|
|
251
|
+
|
msprobe/core/common/log.py
CHANGED
|
@@ -4,13 +4,19 @@ import sys
|
|
|
4
4
|
from functools import wraps
|
|
5
5
|
from msprobe.core.common.const import MsgConst
|
|
6
6
|
|
|
7
|
-
MSPROBE_LOG_LEVEL = os.environ.get(MsgConst.MSPROBE_LOG_LEVEL, "")
|
|
8
|
-
|
|
9
7
|
|
|
10
8
|
class BaseLogger:
|
|
11
9
|
def __init__(self):
|
|
12
10
|
self.rank = None
|
|
11
|
+
self.level = self.get_level()
|
|
13
12
|
|
|
13
|
+
@staticmethod
|
|
14
|
+
def get_level():
|
|
15
|
+
input_level = os.environ.get(MsgConst.MSPROBE_LOG_LEVEL)
|
|
16
|
+
if input_level not in MsgConst.LOG_LEVEL_ENUM:
|
|
17
|
+
return MsgConst.LogLevel.INFO.value
|
|
18
|
+
else:
|
|
19
|
+
return int(input_level)
|
|
14
20
|
|
|
15
21
|
def get_rank(self):
|
|
16
22
|
return self.rank
|
|
@@ -22,23 +28,26 @@ class BaseLogger:
|
|
|
22
28
|
msg = msg.replace(char, '_')
|
|
23
29
|
return func(self, msg, **kwargs)
|
|
24
30
|
return func_level
|
|
25
|
-
|
|
26
|
-
@filter_special_chars
|
|
27
|
-
def info(self, msg, **kwargs):
|
|
28
|
-
self._print_log(MsgConst.LEVEL[0], msg, **kwargs)
|
|
29
|
-
|
|
31
|
+
|
|
30
32
|
@filter_special_chars
|
|
31
33
|
def error(self, msg):
|
|
32
|
-
self.
|
|
34
|
+
if self.level <= MsgConst.LogLevel.ERROR.value:
|
|
35
|
+
self._print_log(MsgConst.LOG_LEVEL[3], msg)
|
|
33
36
|
|
|
34
37
|
@filter_special_chars
|
|
35
38
|
def warning(self, msg):
|
|
36
|
-
self.
|
|
39
|
+
if self.level <= MsgConst.LogLevel.WARNING.value:
|
|
40
|
+
self._print_log(MsgConst.LOG_LEVEL[2], msg)
|
|
41
|
+
|
|
42
|
+
@filter_special_chars
|
|
43
|
+
def info(self, msg):
|
|
44
|
+
if self.level <= MsgConst.LogLevel.INFO.value:
|
|
45
|
+
self._print_log(MsgConst.LOG_LEVEL[1], msg)
|
|
37
46
|
|
|
38
47
|
@filter_special_chars
|
|
39
48
|
def debug(self, msg):
|
|
40
|
-
if
|
|
41
|
-
self._print_log(MsgConst.
|
|
49
|
+
if self.level <= MsgConst.LogLevel.DEBUG.value:
|
|
50
|
+
self._print_log(MsgConst.LOG_LEVEL[0], msg)
|
|
42
51
|
|
|
43
52
|
def on_rank_0(self, func):
|
|
44
53
|
def func_rank_0(*args, **kwargs):
|
|
@@ -73,4 +82,5 @@ class BaseLogger:
|
|
|
73
82
|
print(full_msg, end=end)
|
|
74
83
|
sys.stdout.flush()
|
|
75
84
|
|
|
85
|
+
|
|
76
86
|
logger = BaseLogger()
|