mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/utils.py
CHANGED
|
@@ -22,7 +22,7 @@ import re
|
|
|
22
22
|
|
|
23
23
|
import torch
|
|
24
24
|
|
|
25
|
-
from msprobe.core.common.const import MonitorConst
|
|
25
|
+
from msprobe.core.common.const import MonitorConst
|
|
26
26
|
from msprobe.pytorch.common.log import logger
|
|
27
27
|
from msprobe.core.common.utils import is_int
|
|
28
28
|
from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
|
|
@@ -43,7 +43,6 @@ DIRECTORY_MAX_LENGTH = 4096
|
|
|
43
43
|
|
|
44
44
|
beijing_tz = timezone(timedelta(hours=8))
|
|
45
45
|
MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
|
|
46
|
-
MVGradResult = namedtuple('MVGradResult', ("exp_avg", "exp_avg_sq", "update", "ratio", "grad"))
|
|
47
46
|
|
|
48
47
|
|
|
49
48
|
class MsgConst:
|
|
@@ -102,6 +101,11 @@ def validate_ops(ops):
|
|
|
102
101
|
default_op = MonitorConst.OP_LIST[0]
|
|
103
102
|
valid_ops.append(default_op)
|
|
104
103
|
logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
|
|
104
|
+
# 增加默认shape和dtype参数
|
|
105
|
+
if "shape" not in valid_ops:
|
|
106
|
+
valid_ops.append("shape")
|
|
107
|
+
if "dtype" not in valid_ops:
|
|
108
|
+
valid_ops.append("dtype")
|
|
105
109
|
return valid_ops
|
|
106
110
|
|
|
107
111
|
|
|
@@ -199,7 +203,7 @@ def validate_alert(alert):
|
|
|
199
203
|
args = rule.get("args")
|
|
200
204
|
if args and isinstance(args, dict):
|
|
201
205
|
threshold = args.get("threshold")
|
|
202
|
-
if not isinstance(threshold, float) or threshold < 0:
|
|
206
|
+
if not isinstance(threshold, (float, int)) or threshold < 0:
|
|
203
207
|
raise TypeError('threshold must be float and not less than 0')
|
|
204
208
|
dump = alert.get('dump')
|
|
205
209
|
if dump and not isinstance(dump, bool):
|
|
@@ -220,6 +224,13 @@ def validate_dynamic_on(dynamic_on):
|
|
|
220
224
|
raise TypeError('dynamic_on should be a bool')
|
|
221
225
|
|
|
222
226
|
|
|
227
|
+
def validate_monitor_mbs_grad(monitor_mbs_grad):
|
|
228
|
+
if not isinstance(monitor_mbs_grad, bool):
|
|
229
|
+
logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
|
|
230
|
+
return False
|
|
231
|
+
return monitor_mbs_grad
|
|
232
|
+
|
|
233
|
+
|
|
223
234
|
def validate_config(config):
|
|
224
235
|
config['ops'] = validate_ops(config.get('ops', []))
|
|
225
236
|
|
|
@@ -274,6 +285,8 @@ def validate_config(config):
|
|
|
274
285
|
squash_name = config.get('squash_name', True)
|
|
275
286
|
validate_squash_name(squash_name)
|
|
276
287
|
|
|
288
|
+
config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
|
|
289
|
+
|
|
277
290
|
dynamic_on = config.get('dynamic_on', False)
|
|
278
291
|
validate_dynamic_on(dynamic_on)
|
|
279
292
|
|
|
@@ -208,8 +208,10 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
208
208
|
dispatch_workflow(run_param, data_info)
|
|
209
209
|
else:
|
|
210
210
|
self.lock.acquire()
|
|
211
|
-
|
|
212
|
-
|
|
211
|
+
try:
|
|
212
|
+
self.all_summary.append([])
|
|
213
|
+
finally:
|
|
214
|
+
self.lock.release()
|
|
213
215
|
run_param.process_flag = True
|
|
214
216
|
if self.check_fun(func, run_param):
|
|
215
217
|
data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
|
|
@@ -110,8 +110,11 @@ def dump_data(data, prefix, dump_path):
|
|
|
110
110
|
def save_temp_summary(api_index, single_api_summary, path, lock):
|
|
111
111
|
summary_path = os.path.join(path, f'summary.json')
|
|
112
112
|
lock.acquire()
|
|
113
|
-
|
|
114
|
-
|
|
113
|
+
try:
|
|
114
|
+
data = [api_index, single_api_summary]
|
|
115
|
+
save_json(summary_path, data, mode='a')
|
|
116
|
+
finally:
|
|
117
|
+
lock.release()
|
|
115
118
|
|
|
116
119
|
|
|
117
120
|
def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
|
|
@@ -13,12 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import hashlib
|
|
17
16
|
import os
|
|
18
17
|
import re
|
|
19
18
|
import subprocess
|
|
20
19
|
import sys
|
|
21
20
|
import time
|
|
21
|
+
import zlib
|
|
22
22
|
from collections import namedtuple
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
@@ -114,8 +114,8 @@ class Util:
|
|
|
114
114
|
@staticmethod
|
|
115
115
|
def get_md5_for_numpy(obj):
|
|
116
116
|
np_bytes = obj.tobytes()
|
|
117
|
-
|
|
118
|
-
return
|
|
117
|
+
md5_crc = zlib.crc32(np_bytes)
|
|
118
|
+
return f"{md5_crc:08x}"
|
|
119
119
|
|
|
120
120
|
@staticmethod
|
|
121
121
|
def deal_with_dir_or_file_inconsistency(output_path):
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -18,8 +18,7 @@ import re
|
|
|
18
18
|
|
|
19
19
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
20
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
-
from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path,
|
|
22
|
-
FileChecker
|
|
21
|
+
from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, FileChecker
|
|
23
22
|
from msprobe.core.common.log import logger
|
|
24
23
|
from msprobe.core.common.utils import is_int
|
|
25
24
|
from msprobe.core.common_config import BaseConfig, CommonConfig
|
|
@@ -43,6 +42,7 @@ class TensorConfig(BaseConfig):
|
|
|
43
42
|
self.tls_path = json_config.get("tls_path", "./")
|
|
44
43
|
self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
|
|
45
44
|
self.check_config()
|
|
45
|
+
self._check_summary_mode()
|
|
46
46
|
self._check_file_format()
|
|
47
47
|
if self.online_run_ut:
|
|
48
48
|
self._check_online_run_ut()
|
|
@@ -66,8 +66,10 @@ class TensorConfig(BaseConfig):
|
|
|
66
66
|
check_file_or_directory_path(self.tls_path, isdir=True)
|
|
67
67
|
check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
|
|
68
68
|
check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
|
|
69
|
-
|
|
70
|
-
|
|
69
|
+
check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
|
|
70
|
+
crl_path = os.path.join(self.tls_path, "crl.pem")
|
|
71
|
+
if os.path.exists(crl_path):
|
|
72
|
+
check_file_or_directory_path(crl_path)
|
|
71
73
|
|
|
72
74
|
if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
|
|
73
75
|
raise Exception(f"host: {self.host} is invalid.")
|
|
@@ -82,9 +84,8 @@ class StatisticsConfig(BaseConfig):
|
|
|
82
84
|
self.check_config()
|
|
83
85
|
self._check_summary_mode()
|
|
84
86
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
raise Exception("summary_mode is invalid")
|
|
87
|
+
self.tensor_list = json_config.get("tensor_list", [])
|
|
88
|
+
self._check_str_list_config(self.tensor_list, "tensor_list")
|
|
88
89
|
|
|
89
90
|
|
|
90
91
|
class OverflowCheckConfig(BaseConfig):
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.common.utils import Const
|
|
17
|
+
from msprobe.core.service import BaseService
|
|
18
|
+
from msprobe.pytorch.attl_manager import ATTLManager
|
|
19
|
+
from msprobe.pytorch.common.log import logger
|
|
20
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
|
|
21
|
+
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
22
|
+
from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate
|
|
23
|
+
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
24
|
+
from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func
|
|
25
|
+
from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
|
|
26
|
+
from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
|
|
27
|
+
|
|
28
|
+
if torch_version_above_or_equal_2:
|
|
29
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class PytorchService(BaseService):
|
|
33
|
+
@property
|
|
34
|
+
def _get_framework_type(self):
|
|
35
|
+
return Const.PT_FRAMEWORK
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def _get_current_rank():
|
|
39
|
+
return get_rank_if_initialized()
|
|
40
|
+
|
|
41
|
+
def reset_status(self):
|
|
42
|
+
self._reset_status()
|
|
43
|
+
|
|
44
|
+
def _init_specific_components(self):
|
|
45
|
+
self.logger = logger
|
|
46
|
+
self.api_register = get_api_register()
|
|
47
|
+
self.module_processor = ModuleProcesser(self.data_collector.scope)
|
|
48
|
+
self.attl_manager = ATTLManager(self.config)
|
|
49
|
+
self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
|
|
50
|
+
self.api_template = ApiTemplate
|
|
51
|
+
|
|
52
|
+
def _register_hook(self):
|
|
53
|
+
self.attl_manager.attl_init()
|
|
54
|
+
if self._is_mix_level:
|
|
55
|
+
register_optimizer_hook(self.data_collector)
|
|
56
|
+
|
|
57
|
+
def _register_api_hook(self):
|
|
58
|
+
super()._register_api_hook()
|
|
59
|
+
wrap_jit_script_func()
|
|
60
|
+
|
|
61
|
+
def _register_module_hook(self):
|
|
62
|
+
ModuleProcesser.enable_module_dump = True
|
|
63
|
+
self.module_processor.register_module_hook(self.model, self.build_hook)
|
|
64
|
+
self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
|
|
65
|
+
|
|
66
|
+
def _run_ut_dispatch(self, status):
|
|
67
|
+
if torch_version_above_or_equal_2:
|
|
68
|
+
run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
|
|
69
|
+
|
|
70
|
+
def _reset_status(self):
|
|
71
|
+
super()._reset_status()
|
|
72
|
+
ModuleProcesser.reset_module_stats()
|
|
73
|
+
HOOKModule.reset_module_stats()
|
|
@@ -14,9 +14,11 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
+
from dataclasses import dataclass
|
|
17
18
|
|
|
18
19
|
from msprobe.core.common.const import Const
|
|
19
20
|
from msprobe.core.common.file_utils import load_json, save_json
|
|
21
|
+
from msprobe.core.common.utils import load_stack_json
|
|
20
22
|
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
23
|
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
22
24
|
from msprobe.visualization.graph.graph import Graph
|
|
@@ -44,7 +46,7 @@ class GraphBuilder:
|
|
|
44
46
|
"""
|
|
45
47
|
construct_dict = load_json(construct_path)
|
|
46
48
|
dump_dict = load_json(data_path)
|
|
47
|
-
stack_dict =
|
|
49
|
+
stack_dict = load_stack_json(stack_path)
|
|
48
50
|
if not complete_stack:
|
|
49
51
|
GraphBuilder._simplify_stack(stack_dict)
|
|
50
52
|
data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
|
|
@@ -61,10 +63,10 @@ class GraphBuilder:
|
|
|
61
63
|
"""
|
|
62
64
|
result = {}
|
|
63
65
|
if config.graph_b:
|
|
64
|
-
result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
|
|
65
|
-
result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
|
|
66
|
+
result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode)
|
|
67
|
+
result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode)
|
|
66
68
|
else:
|
|
67
|
-
result = config.graph_n.to_dict()
|
|
69
|
+
result = config.graph_n.to_dict(config.compare_mode)
|
|
68
70
|
if config.tool_tip:
|
|
69
71
|
result[GraphConst.JSON_TIP_KEY] = config.tool_tip
|
|
70
72
|
if config.node_colors:
|
|
@@ -187,6 +189,8 @@ class GraphBuilder:
|
|
|
187
189
|
# 数据格式:"output": [[{param1}, {param2}, ...]]
|
|
188
190
|
if GraphBuilder._is_valid_batch_p2p_output(param_list):
|
|
189
191
|
for param in param_list[0]:
|
|
192
|
+
if not isinstance(param, dict):
|
|
193
|
+
continue
|
|
190
194
|
info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
|
|
191
195
|
GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
|
|
192
196
|
node.batch_p2p_info.append(info)
|
|
@@ -254,14 +258,12 @@ class GraphBuilder:
|
|
|
254
258
|
max_info = {prefix: 0 for prefix in prefixes}
|
|
255
259
|
|
|
256
260
|
for key in graph.node_map.keys():
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
if
|
|
262
|
-
|
|
263
|
-
if num > max_info[prefix]:
|
|
264
|
-
max_info[prefix] = num
|
|
261
|
+
parts = key.split(Const.SEP)
|
|
262
|
+
if len(parts) > 2 and parts[-2] == Const.BACKWARD:
|
|
263
|
+
num = int(parts[-1])
|
|
264
|
+
prefix = Const.SEP.join(parts[:-2])
|
|
265
|
+
if prefix in max_info and num > max_info[prefix]:
|
|
266
|
+
max_info[prefix] = num
|
|
265
267
|
|
|
266
268
|
for prefix, num in max_info.items():
|
|
267
269
|
node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num)
|
|
@@ -277,7 +279,7 @@ class GraphBuilder:
|
|
|
277
279
|
|
|
278
280
|
class GraphExportConfig:
|
|
279
281
|
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
|
|
280
|
-
overflow_check=False):
|
|
282
|
+
overflow_check=False, compare_mode=None):
|
|
281
283
|
self.graph_n = graph_n
|
|
282
284
|
self.graph_b = graph_b
|
|
283
285
|
self.tool_tip = tool_tip
|
|
@@ -285,3 +287,21 @@ class GraphExportConfig:
|
|
|
285
287
|
self.micro_steps = micro_steps
|
|
286
288
|
self.task = task
|
|
287
289
|
self.overflow_check = overflow_check
|
|
290
|
+
self.compare_mode = compare_mode
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@dataclass
|
|
294
|
+
class GraphInfo:
|
|
295
|
+
graph: Graph
|
|
296
|
+
construct_path: str
|
|
297
|
+
data_path: str
|
|
298
|
+
stack_path: str
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@dataclass
|
|
302
|
+
class BuildGraphTaskInfo:
|
|
303
|
+
graph_info_n: GraphInfo
|
|
304
|
+
graph_info_b: GraphInfo
|
|
305
|
+
npu_rank: str
|
|
306
|
+
bench_rank: str
|
|
307
|
+
time_str: str
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -12,12 +12,16 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
+
|
|
15
16
|
import re
|
|
16
|
-
|
|
17
|
+
|
|
18
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
19
|
+
from msprobe.core.compare.multiprocessing_compute import CompareRealData
|
|
20
|
+
from msprobe.core.compare.utils import read_op, merge_tensor, get_accuracy, make_result_table
|
|
17
21
|
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
18
22
|
from msprobe.visualization.utils import GraphConst
|
|
19
23
|
from msprobe.core.common.const import Const
|
|
20
|
-
|
|
24
|
+
|
|
21
25
|
|
|
22
26
|
# 用于将节点名字解析成对应的NodeOp的规则
|
|
23
27
|
op_patterns = [
|
|
@@ -53,13 +57,11 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
|
|
|
53
57
|
mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
|
|
54
58
|
|
|
55
59
|
if framework == Const.PT_FRAMEWORK:
|
|
56
|
-
from msprobe.pytorch.compare.pt_compare import
|
|
57
|
-
return
|
|
60
|
+
from msprobe.pytorch.compare.pt_compare import read_real_data
|
|
61
|
+
return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path)
|
|
58
62
|
else:
|
|
59
|
-
from msprobe.mindspore.compare.ms_compare import
|
|
60
|
-
|
|
61
|
-
ms_comparator.cross_frame = is_cross_frame
|
|
62
|
-
return ms_comparator.do_multi_process(dump_path_param, csv_path)
|
|
63
|
+
from msprobe.mindspore.compare.ms_compare import read_real_data
|
|
64
|
+
return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path)
|
|
63
65
|
|
|
64
66
|
|
|
65
67
|
def get_input_output(node_data, node_id):
|
|
@@ -119,11 +121,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
|
|
|
119
121
|
return True
|
|
120
122
|
|
|
121
123
|
|
|
122
|
-
def format_node_data(data_dict, node_id=None):
|
|
124
|
+
def format_node_data(data_dict, node_id=None, compare_mode=None):
|
|
123
125
|
"""
|
|
124
126
|
删除节点数据中不需要展示的字段
|
|
125
127
|
"""
|
|
126
128
|
del_list = ['requires_grad', 'full_op_name']
|
|
129
|
+
if GraphConst.MD5_COMPARE != compare_mode:
|
|
130
|
+
del_list.append(Const.MD5)
|
|
127
131
|
if node_id and GraphConst.BATCH_P2P in node_id:
|
|
128
132
|
del_list.extend(['op', 'peer', 'tag', 'group_id'])
|
|
129
133
|
for _, value in data_dict.items():
|
|
@@ -171,7 +175,7 @@ def _format_decimal_string(s):
|
|
|
171
175
|
"""
|
|
172
176
|
使用正则表达式匹配包含数字、小数点和可选的百分号的字符串
|
|
173
177
|
"""
|
|
174
|
-
pattern = re.compile(r'
|
|
178
|
+
pattern = re.compile(r'^\d{1,20}\.\d{1,20}%?$')
|
|
175
179
|
matches = pattern.findall(s)
|
|
176
180
|
for match in matches:
|
|
177
181
|
is_percent = match.endswith('%')
|
|
@@ -226,3 +230,12 @@ def _format_data(data_dict):
|
|
|
226
230
|
if all_null:
|
|
227
231
|
data_dict.clear()
|
|
228
232
|
data_dict[GraphConst.VALUE] = GraphConst.NULL
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def get_csv_df(stack_mode, csv_data, compare_mode):
|
|
236
|
+
"""
|
|
237
|
+
调用acc接口写入csv
|
|
238
|
+
"""
|
|
239
|
+
|
|
240
|
+
dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
|
|
241
|
+
return make_result_table(csv_data, dump_mode, stack_mode)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
-
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
18
|
-
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file
|
|
17
|
+
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, get_csv_df
|
|
18
|
+
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file
|
|
19
19
|
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
20
20
|
from msprobe.visualization.compare.mode_adapter import ModeAdapter
|
|
21
21
|
from msprobe.core.common.const import Const
|
|
@@ -25,14 +25,16 @@ from msprobe.core.common.decorator import recursion_depth_decorator
|
|
|
25
25
|
class GraphComparator:
|
|
26
26
|
MAX_DEPTH = 1000
|
|
27
27
|
|
|
28
|
-
def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
|
|
28
|
+
def __init__(self, graphs, dump_path_param, args, is_cross_framework, mapping_dict=None):
|
|
29
29
|
self.graph_n = graphs[0]
|
|
30
30
|
self.graph_b = graphs[1]
|
|
31
31
|
self._parse_param(dump_path_param, args.output_path)
|
|
32
32
|
self.framework = args.framework
|
|
33
|
+
self.layer_mapping = args.layer_mapping
|
|
33
34
|
self.mapping_dict = mapping_dict
|
|
34
35
|
self.fuzzy_match = args.fuzzy_match
|
|
35
36
|
self.pattern = re.compile(r'\.\d+\.')
|
|
37
|
+
self.is_cross_framework = is_cross_framework
|
|
36
38
|
|
|
37
39
|
def compare(self):
|
|
38
40
|
"""
|
|
@@ -69,50 +71,56 @@ class GraphComparator:
|
|
|
69
71
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
70
72
|
node.data.update(other_dict)
|
|
71
73
|
|
|
72
|
-
|
|
73
|
-
def _compare_nodes(self, node_n):
|
|
74
|
+
def _compare_nodes(self, node_root):
    """
    Walk the tree rooted at ``node_root`` and compare every node that has a counterpart in ``graph_b``.

    Nodes are taken from a FIFO queue, so the tree is visited level by level and a node
    is always handled before any of its subnodes.

    For each node a counterpart is looked up with ``Graph.mapping_match`` when
    ``self.layer_mapping`` is set, otherwise with ``Graph.match``. When one is found the
    two nodes are linked and ``_get_and_add_result`` is called for the pair.

    :param node_root: node at which the traversal starts
    """
    # Local import keeps this method self-contained; deque gives O(1) pops from the left,
    # whereas list.pop(0) shifts every remaining element on each call.
    from collections import deque

    pending = deque([node_root])
    while pending:
        node_n = pending.popleft()
        if self.layer_mapping:
            node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
            if node_b:
                # Each side's link is the other side's ancestor chain ending at the matched node.
                ancestors_n.append(node_n.id)
                ancestors_b.append(node_b.id)
                node_n.matched_node_link = ancestors_b
                node_b.matched_node_link = ancestors_n
        else:
            node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
            if node_b:
                ancestors.append(node_b.id)
                node_n.add_link(node_b, ancestors)
        if node_b:
            # Real-data comparison only yields basic information at this point, without
            # precision metrics; the multi-process comparison interface is called for those.
            self._get_and_add_result(node_n, node_b)
        # Subnodes are queued whether or not this node was matched.
        pending.extend(node_n.subnodes)
|
|
100
|
+
|
|
101
|
+
def _compare_nodes_fuzzy(self, node_root):
    """
    Walk the tree rooted at ``node_root`` and fuzzy-match its nodes against ``graph_b``.

    Nodes are taken from a FIFO queue, so a node is always handled before its subnodes.
    Only nodes whose ``op`` is not ``NodeOp.function_api`` are matched directly, against
    the ``graph_b`` node with the same id. Once such a pair matches, the API nodes of
    both sides are re-keyed with ``_recount_api_node`` and matched pairwise by that key.

    :param node_root: node at which the traversal starts
    """
    # Local import keeps this method self-contained; deque gives O(1) pops from the left,
    # whereas list.pop(0) shifts every remaining element on each call.
    from collections import deque

    pending = deque([node_root])
    while pending:
        node_n = pending.popleft()
        if node_n.op != NodeOp.function_api:
            # Modules go through fuzzy matching.
            node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
            if node_b:
                self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
                # For all APIs inside the two matched modules, ignore the dump call count and
                # match by identical name plus call order within the module.
                recount_result_n = self._recount_api_node(node_n)
                recount_result_b = self._recount_api_node(node_b)
                for recount_node_id, node_id_n in recount_result_n.items():
                    api_node_n = self.graph_n.node_map.get(node_id_n)
                    if not api_node_n:
                        continue
                    api_node_b, api_ancestors_n, api_ancestors_b = Graph.fuzzy_match(
                        api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
                    if api_node_b:
                        self._process_matched_nodes(api_node_n, api_node_b, api_ancestors_n, api_ancestors_b)
        # Subnodes are queued whether or not this node was matched.
        pending.extend(node_n.subnodes)
|
|
116
124
|
|
|
117
125
|
def _parse_param(self, dump_path_param, output_path):
|
|
118
126
|
self.dump_path_param = dump_path_param
|
|
@@ -128,7 +136,7 @@ class GraphComparator:
|
|
|
128
136
|
if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
129
137
|
return
|
|
130
138
|
df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
|
|
131
|
-
df = run_real_data(self.dump_path_param, df, self.framework,
|
|
139
|
+
df = run_real_data(self.dump_path_param, df, self.framework, self.is_cross_framework)
|
|
132
140
|
compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
|
|
133
141
|
for node in self.ma.compare_nodes:
|
|
134
142
|
precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import math
|
|
16
17
|
import json
|
|
17
18
|
from msprobe.core.common.const import CompareConst, Const
|
|
18
19
|
from msprobe.visualization.utils import ToolTip, GraphConst, str2float
|
|
@@ -24,6 +25,12 @@ class ModeAdapter:
|
|
|
24
25
|
self.csv_data = []
|
|
25
26
|
self.compare_nodes = []
|
|
26
27
|
|
|
28
|
+
@staticmethod
|
|
29
|
+
def _is_invalid(value):
|
|
30
|
+
if not isinstance(value, float):
|
|
31
|
+
return False
|
|
32
|
+
return math.isnan(value) or math.isinf(value)
|
|
33
|
+
|
|
27
34
|
@staticmethod
|
|
28
35
|
def _add_md5_compare_data(node_data, compare_data_dict):
|
|
29
36
|
precision_index = GraphConst.MAX_INDEX_KEY
|
|
@@ -48,6 +55,8 @@ class ModeAdapter:
|
|
|
48
55
|
for key, value in node_data.items():
|
|
49
56
|
if not isinstance(value, dict):
|
|
50
57
|
continue
|
|
58
|
+
if value.get(Const.MAX) is None:
|
|
59
|
+
continue
|
|
51
60
|
compare_data = compare_data_dict.get(key)
|
|
52
61
|
if compare_data:
|
|
53
62
|
headers = CompareConst.COMPARE_RESULT_HEADER
|
|
@@ -66,9 +75,13 @@ class ModeAdapter:
|
|
|
66
75
|
if thousandth is not None:
|
|
67
76
|
numbers.append(thousandth)
|
|
68
77
|
node_data[key] = value
|
|
78
|
+
if ModeAdapter._is_invalid(value.get(Const.MAX)) or ModeAdapter._is_invalid(value.get(Const.MIN)):
|
|
79
|
+
numbers.append(CompareConst.N_A)
|
|
69
80
|
# 双千指标都是None的异常情况
|
|
70
81
|
if not numbers:
|
|
71
82
|
min_thousandth = None
|
|
83
|
+
elif CompareConst.N_A in numbers:
|
|
84
|
+
min_thousandth = CompareConst.N_A
|
|
72
85
|
else:
|
|
73
86
|
min_thousandth = min(numbers + [min_thousandth])
|
|
74
87
|
return min_thousandth
|
|
@@ -80,6 +93,8 @@ class ModeAdapter:
|
|
|
80
93
|
for key, data_info in node_data.items():
|
|
81
94
|
if not isinstance(data_info, dict):
|
|
82
95
|
continue
|
|
96
|
+
if data_info.get(Const.MAX) is None:
|
|
97
|
+
continue
|
|
83
98
|
compare_data = compare_data_dict.get(key)
|
|
84
99
|
if compare_data:
|
|
85
100
|
# 对应比对结果csv的列
|
|
@@ -91,6 +106,8 @@ class ModeAdapter:
|
|
|
91
106
|
relative_err = str2float(data_info.get(item))
|
|
92
107
|
max_relative_err = max(max_relative_err, relative_err)
|
|
93
108
|
node_data[key] = data_info
|
|
109
|
+
if ModeAdapter._is_invalid(data_info.get(Const.MAX)) or ModeAdapter._is_invalid(data_info.get(Const.MIN)):
|
|
110
|
+
max_relative_err = GraphConst.MAX_INDEX_KEY
|
|
94
111
|
max_relative_err = 1 if max_relative_err > 1 else max_relative_err
|
|
95
112
|
return max_relative_err
|
|
96
113
|
|
|
@@ -132,7 +149,11 @@ class ModeAdapter:
|
|
|
132
149
|
ModeAdapter._check_list_len(compare_data_dict_list, 1)
|
|
133
150
|
min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0])
|
|
134
151
|
min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0])
|
|
135
|
-
if
|
|
152
|
+
if CompareConst.N_A == min_thousandth_out:
|
|
153
|
+
change_percentage = GraphConst.MAX_INDEX_KEY
|
|
154
|
+
elif CompareConst.N_A == min_thousandth_in:
|
|
155
|
+
change_percentage = GraphConst.MIN_INDEX_KEY
|
|
156
|
+
elif min_thousandth_in is not None and min_thousandth_out is not None:
|
|
136
157
|
change_percentage = min_thousandth_in - min_thousandth_out
|
|
137
158
|
else:
|
|
138
159
|
change_percentage = GraphConst.MIN_INDEX_KEY
|
|
@@ -140,6 +161,7 @@ class ModeAdapter:
|
|
|
140
161
|
else change_percentage
|
|
141
162
|
precision_index = GraphConst.MAX_INDEX_KEY \
|
|
142
163
|
if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
|
|
164
|
+
precision_index = self._ignore_precision_index(node.id, precision_index)
|
|
143
165
|
return precision_index, other_dict
|
|
144
166
|
|
|
145
167
|
def prepare_real_data(self, node):
|
|
@@ -176,3 +198,11 @@ class ModeAdapter:
|
|
|
176
198
|
CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
|
|
177
199
|
CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
|
|
178
200
|
return json.dumps(tips)
|
|
201
|
+
|
|
202
|
+
def _ignore_precision_index(self, node_id, precision_index):
    """
    Override the precision index for nodes whose type is on the ignore list.

    ``node_id`` is split on ``Const.SEP``. When it has at least two segments and the second
    one is in ``GraphConst.IGNORE_PRECISION_INDEX``, a fixed value is returned instead of
    the computed index: ``GraphConst.MAX_INDEX_KEY`` in MD5 compare mode, otherwise
    ``GraphConst.MIN_INDEX_KEY``.
    NOTE(review): presumably these are the "no precision problem" values of each mode -- confirm.

    :param node_id: node identifier string
    :param precision_index: index computed for the node
    :return: the fixed override value, or ``precision_index`` unchanged
    """
    segments = node_id.split(Const.SEP)
    if len(segments) >= 2 and segments[1] in GraphConst.IGNORE_PRECISION_INDEX:
        if self.compare_mode == GraphConst.MD5_COMPARE:
            return GraphConst.MAX_INDEX_KEY
        return GraphConst.MIN_INDEX_KEY
    return precision_index
|
|
@@ -87,15 +87,15 @@ class BaseNode:
|
|
|
87
87
|
self.matched_node_link = ancestors
|
|
88
88
|
node.matched_node_link = ancestors
|
|
89
89
|
|
|
90
|
-
def to_dict(self):
|
|
90
|
+
def to_dict(self, compare_mode=None):
|
|
91
91
|
"""
|
|
92
92
|
输出数据
|
|
93
93
|
"""
|
|
94
94
|
result = {
|
|
95
95
|
'id': self.id,
|
|
96
96
|
'node_type': self.op.value,
|
|
97
|
-
'output_data': format_node_data(self.output_data, self.id),
|
|
98
|
-
'input_data': format_node_data(self.input_data, self.id),
|
|
97
|
+
'output_data': format_node_data(self.output_data, self.id, compare_mode),
|
|
98
|
+
'input_data': format_node_data(self.input_data, self.id, compare_mode),
|
|
99
99
|
'upnode': self.upnode.id if self.upnode else 'None',
|
|
100
100
|
'subnodes': [node.id for node in self.subnodes],
|
|
101
101
|
'matched_node_link': self.matched_node_link,
|