mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
- msprobe/README.md +7 -5
- msprobe/core/common/const.py +6 -0
- msprobe/core/common/db_manager.py +35 -4
- msprobe/core/common/file_utils.py +105 -27
- msprobe/core/common/framework_adapter.py +7 -6
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/utils.py +14 -3
- msprobe/core/compare/find_first/analyzer.py +8 -7
- msprobe/core/compare/find_first/graph.py +11 -3
- msprobe/core/compare/find_first/utils.py +2 -1
- msprobe/core/compare/highlight.py +13 -6
- msprobe/core/compare/multiprocessing_compute.py +17 -10
- msprobe/core/compare/utils.py +14 -5
- msprobe/core/data_dump/data_collector.py +18 -21
- msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
- msprobe/core/data_dump/json_writer.py +18 -8
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +37 -3
- msprobe/core/service.py +18 -5
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +7 -5
- msprobe/docs/02.config_introduction.md +14 -1
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/06.data_dump_MindSpore.md +1 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +15 -80
- msprobe/docs/22.visualization_MindSpore.md +20 -104
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/cell_processor.py +33 -5
- msprobe/mindspore/compare/common_dir_compare.py +22 -26
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/mindspore/debugger/precision_debugger.py +1 -1
- msprobe/mindspore/dump/cell_dump_process.py +73 -62
- msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
- msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/common/utils.py +22 -2
- msprobe/pytorch/compare/utils.py +3 -3
- msprobe/pytorch/debugger/debugger_config.py +10 -0
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
- msprobe/pytorch/hook_module/api_register.py +6 -1
- msprobe/pytorch/monitor/module_hook.py +28 -9
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/pt_config.py +57 -2
- msprobe/pytorch/pytorch_service.py +11 -2
- msprobe/visualization/builder/graph_builder.py +170 -64
- msprobe/visualization/builder/graph_merger.py +0 -1
- msprobe/visualization/builder/msprobe_adapter.py +1 -1
- msprobe/visualization/db_utils.py +25 -2
- msprobe/visualization/graph/base_node.py +0 -24
- msprobe/visualization/graph/graph.py +5 -14
- msprobe/visualization/graph_service.py +29 -53
- msprobe/visualization/utils.py +11 -1
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
|
@@ -26,8 +26,7 @@ def read_npy_data(dir_path, file_name):
|
|
|
26
26
|
return None
|
|
27
27
|
|
|
28
28
|
data_path = os.path.join(dir_path, file_name)
|
|
29
|
-
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
30
|
-
FileCheckConst.NUMPY_SUFFIX, False)
|
|
29
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX)
|
|
31
30
|
data_path = path_checker.common_check()
|
|
32
31
|
data_value = load_npy(data_path)
|
|
33
32
|
return data_value
|
|
@@ -182,7 +182,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
182
182
|
with ThreadSafe():
|
|
183
183
|
instance.service.step()
|
|
184
184
|
if is_graph_mode_cell_dump_allowed(instance.config):
|
|
185
|
-
GraphModeCellDump.step()
|
|
185
|
+
GraphModeCellDump.step(instance.config.dump_path, instance.config.step, instance.config.task)
|
|
186
186
|
if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2:
|
|
187
187
|
_dump_step(1)
|
|
188
188
|
if cls._is_kernel_dump() and _msprobe_c:
|
|
@@ -46,9 +46,11 @@ KEY_FORWARD = CoreConst.FORWARD
|
|
|
46
46
|
KEY_BACKWARD = CoreConst.BACKWARD
|
|
47
47
|
KEY_INPUT = CoreConst.INPUT
|
|
48
48
|
KEY_OUTPUT = CoreConst.OUTPUT
|
|
49
|
-
KEY_DUMP_TENSOR_DATA = "
|
|
49
|
+
KEY_DUMP_TENSOR_DATA = "dump_tensor_data/"
|
|
50
50
|
KEY_STATISTIC_CSV = "statistic.csv"
|
|
51
51
|
KEY_TD_FLAG = "td_flag"
|
|
52
|
+
# 设置落盘文件检测超时时间
|
|
53
|
+
TIMEOUT = 600
|
|
52
54
|
td = ops.TensorDump()
|
|
53
55
|
if (ms.__version__ >= "2.5.0"):
|
|
54
56
|
td_in = ops.TensorDump("in")
|
|
@@ -574,28 +576,33 @@ def generate_stack_info(path):
|
|
|
574
576
|
logger.info(f"Stack data saved to {json_path}")
|
|
575
577
|
|
|
576
578
|
|
|
577
|
-
def is_download_finished(directory,
|
|
579
|
+
def is_download_finished(directory, save_flag):
|
|
578
580
|
"""
|
|
579
581
|
判断指定目录在一段时间后是否有数据被下载完成
|
|
580
582
|
:param directory: 指定目录的路径
|
|
581
|
-
:param
|
|
583
|
+
:param save_flag: 数据落盘完成后的标志文件
|
|
582
584
|
:return: 如有数据被下载完成返回 True,否则返回 False
|
|
583
585
|
"""
|
|
586
|
+
# 设定一定的延迟间隔,避免频繁进行磁盘的io读取操作
|
|
587
|
+
time.sleep(0.5)
|
|
588
|
+
logger.info("Waiting for download...")
|
|
584
589
|
# 检查目录是否存在
|
|
585
590
|
if not os.path.exists(directory):
|
|
586
591
|
logger.warning(f"The specified directory {directory} does not exist.")
|
|
587
592
|
return False
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
593
|
+
|
|
594
|
+
# 遍历当前目录中的所有条目
|
|
595
|
+
for entry_path in os.listdir(directory):
|
|
596
|
+
if entry_path.startswith(save_flag):
|
|
597
|
+
return True
|
|
598
|
+
|
|
599
|
+
return False
|
|
600
|
+
|
|
596
601
|
|
|
602
|
+
def process_step(dump_path, flag_path, step, step_list):
|
|
603
|
+
if step not in step_list:
|
|
604
|
+
return
|
|
597
605
|
|
|
598
|
-
def process(dump_path):
|
|
599
606
|
if not os.path.exists(dump_path):
|
|
600
607
|
logger.warning('No grap cell data is dumped.')
|
|
601
608
|
create_directory(dump_path)
|
|
@@ -606,32 +613,38 @@ def process(dump_path):
|
|
|
606
613
|
if rank_id is not None:
|
|
607
614
|
rank_dir = CoreConst.RANK + str(rank_id)
|
|
608
615
|
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
616
|
+
step_dir = CoreConst.STEP + str(step)
|
|
617
|
+
|
|
618
|
+
step_path = os.path.join(dump_path, step_dir)
|
|
619
|
+
rank_path = os.path.join(step_path, rank_dir)
|
|
620
|
+
npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
|
|
621
|
+
save_finish_flag = f"step_{step}"
|
|
622
|
+
start_time = time.time()
|
|
623
|
+
while True:
|
|
624
|
+
is_finished = is_download_finished(flag_path, save_finish_flag)
|
|
625
|
+
if not is_finished:
|
|
626
|
+
logger.info("There is data being downloaded in the specified directory, continue checking...")
|
|
627
|
+
else:
|
|
628
|
+
logger.info("There is no data being downloaded in the specified directory, Stop checking.")
|
|
629
|
+
break
|
|
630
|
+
elapsed_time = time.time() - start_time
|
|
631
|
+
if elapsed_time > TIMEOUT:
|
|
632
|
+
logger.error(f"Check timed out after {TIMEOUT} seconds. Exiting.")
|
|
633
|
+
return
|
|
634
|
+
logger.info(f"==========Start processing step_{step}'s data that has already been stored on the disk!==========")
|
|
635
|
+
rename_filename(path=npy_path)
|
|
636
|
+
generate_construct(npy_path)
|
|
637
|
+
generate_dump_info(npy_path)
|
|
638
|
+
generate_stack_info(npy_path)
|
|
639
|
+
# 单卡场景,rank目录名称为rank
|
|
640
|
+
if rank_id is None:
|
|
641
|
+
new_rank_path = os.path.join(step_path, CoreConst.RANK)
|
|
642
|
+
try:
|
|
643
|
+
move_directory(rank_path, new_rank_path)
|
|
644
|
+
logger.info(f"Directory was successfully renamed to: {new_rank_path}")
|
|
645
|
+
except Exception as e:
|
|
646
|
+
logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
|
|
647
|
+
logger.info(f"==========Step_{step}'s JSON file generation completed!==========")
|
|
635
648
|
|
|
636
649
|
|
|
637
650
|
# 删除csv文件中每行数据最后面的逗号
|
|
@@ -689,7 +702,10 @@ def merge_file(dump_path, rank_dir, file_dict):
|
|
|
689
702
|
" and the index is out of bounds.")
|
|
690
703
|
|
|
691
704
|
|
|
692
|
-
def
|
|
705
|
+
def process_statistics_step(dump_path, step, step_list):
|
|
706
|
+
if step_list and step not in step_list:
|
|
707
|
+
return
|
|
708
|
+
|
|
693
709
|
if not os.path.exists(dump_path):
|
|
694
710
|
logger.warning('No grap cell data is dumped.')
|
|
695
711
|
create_directory(dump_path)
|
|
@@ -723,25 +739,24 @@ def process_statistics(dump_path):
|
|
|
723
739
|
|
|
724
740
|
rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '')
|
|
725
741
|
dir_list = os.listdir(dump_path)
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
logger.info("==========JSON file generation completed!==========")
|
|
742
|
+
step_dir = CoreConst.STEP + str(step)
|
|
743
|
+
step_path = os.path.join(dump_path, step_dir)
|
|
744
|
+
rank_path = os.path.join(step_path, rank_dir)
|
|
745
|
+
csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
|
|
746
|
+
logger.info("==========Start processing data csv!==========")
|
|
747
|
+
generate_construct(csv_path)
|
|
748
|
+
generate_dump_info(csv_path)
|
|
749
|
+
generate_stack_info(csv_path)
|
|
750
|
+
remove_path(rank_path_kbk)
|
|
751
|
+
# 单卡场景,rank目录名称为rank
|
|
752
|
+
if rank_id is None:
|
|
753
|
+
new_rank_path = os.path.join(step_path, CoreConst.RANK)
|
|
754
|
+
try:
|
|
755
|
+
move_directory(rank_path, new_rank_path)
|
|
756
|
+
logger.info(f"Directory was successfully renamed to: {new_rank_path}")
|
|
757
|
+
except Exception as e:
|
|
758
|
+
logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
|
|
759
|
+
logger.info("==========JSON file generation completed!==========")
|
|
745
760
|
|
|
746
761
|
|
|
747
762
|
def get_yaml_keys(yaml_data):
|
|
@@ -922,7 +937,3 @@ def start(config: CellDumpConfig):
|
|
|
922
937
|
cell.data_mode = data_mode
|
|
923
938
|
|
|
924
939
|
logger.info("==========The cell_dump_process_start phase is Finished!==========")
|
|
925
|
-
if dump_task == CoreConst.TENSOR:
|
|
926
|
-
atexit.register(process, dump_path=dump_path)
|
|
927
|
-
if dump_task == CoreConst.STATISTICS:
|
|
928
|
-
atexit.register(process_statistics, dump_path=dump_path)
|
|
@@ -14,7 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
|
|
17
|
+
import glob
|
|
18
|
+
import tempfile
|
|
18
19
|
import mindspore as ms
|
|
19
20
|
from mindspore import hal, ops, Tensor
|
|
20
21
|
from mindspore.ops.primitive import _run_op
|
|
@@ -28,6 +29,7 @@ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
|
|
|
28
29
|
import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
|
|
29
30
|
|
|
30
31
|
tensordump_flag = True
|
|
32
|
+
DEFAULT_RANK_DIR = "rank0"
|
|
31
33
|
try:
|
|
32
34
|
from mindspore._c_expression import _tensordump_set_step
|
|
33
35
|
except ImportError:
|
|
@@ -41,8 +43,6 @@ except ImportError:
|
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
class GraphModeCellDump:
|
|
44
|
-
task = CoreConst.STATISTICS
|
|
45
|
-
|
|
46
46
|
def __init__(self, config: DebuggerConfig, model, strict=True):
|
|
47
47
|
self.net = model
|
|
48
48
|
self.white_list = []
|
|
@@ -55,29 +55,40 @@ class GraphModeCellDump:
|
|
|
55
55
|
self.list = config.list
|
|
56
56
|
self.data_mode = config.data_mode
|
|
57
57
|
self.file_format = config.file_format
|
|
58
|
-
GraphModeCellDump.task = config.task
|
|
59
58
|
self.summary_mode = config.summary_mode
|
|
59
|
+
self.task = config.task
|
|
60
60
|
self.check_config(strict)
|
|
61
61
|
self.set_step()
|
|
62
62
|
|
|
63
63
|
@staticmethod
|
|
64
|
-
def step():
|
|
64
|
+
def step(dump_path, step_list, task):
|
|
65
65
|
# 更新TensorDump Step
|
|
66
|
-
if
|
|
66
|
+
if task == CoreConst.TENSOR:
|
|
67
67
|
hal.synchronize()
|
|
68
68
|
temp_tensor = ms.Tensor([1], dtype=ms.float32)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
69
|
+
rank_id = os.environ.get('RANK_ID')
|
|
70
|
+
rank_dir = DEFAULT_RANK_DIR
|
|
71
|
+
|
|
72
|
+
if rank_id is not None:
|
|
73
|
+
rank_dir = CoreConst.RANK + str(rank_id)
|
|
74
|
+
|
|
75
|
+
with tempfile.TemporaryDirectory(dir=dump_path, prefix=rank_dir) as temp_dir:
|
|
76
|
+
save_file_flag = f"{temp_dir}/step_{Runtime.step_count}"
|
|
77
|
+
_run_op(ops.TensorDump(), "TensorDump", (save_file_flag, temp_tensor))
|
|
78
|
+
step_flag = "<tensordump-update-step>"
|
|
79
|
+
_run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
|
|
80
|
+
ops.tensordump(step_flag, temp_tensor)
|
|
81
|
+
cellDumperWithDumpGradient.process_step(dump_path, temp_dir, Runtime.step_count, step_list)
|
|
72
82
|
|
|
73
83
|
# 更新静态图KBK dump的step数
|
|
74
|
-
if
|
|
84
|
+
if task == CoreConst.STATISTICS:
|
|
75
85
|
if not graph_step_flag:
|
|
76
86
|
raise Exception(
|
|
77
87
|
"Importing _dump_step failed, "
|
|
78
88
|
"please use the latest version package of MindSpore."
|
|
79
89
|
)
|
|
80
90
|
_dump_step(1)
|
|
91
|
+
cellDumperWithDumpGradient.process_statistics_step(dump_path, Runtime.step_count, step_list)
|
|
81
92
|
|
|
82
93
|
def check_config(self, strict):
|
|
83
94
|
if not self.net:
|
|
@@ -203,10 +203,12 @@ class MindsporeHookManager(BaseHookManager):
|
|
|
203
203
|
return
|
|
204
204
|
|
|
205
205
|
with ThreadSafe():
|
|
206
|
+
original_state = self.ensure_gc_enabled()
|
|
206
207
|
BaseHookManager.inner_switch[tid] = True
|
|
207
208
|
module_input = ModuleBackwardInputs(grad_input=grad_input)
|
|
208
209
|
self.data_collector.update_api_or_module_name(full_name)
|
|
209
210
|
self.data_collector.backward_input_data_collect(full_name, module, self._pid, module_input)
|
|
210
211
|
BaseHookManager.inner_switch[tid] = False
|
|
212
|
+
self.restore_gc_state(original_state)
|
|
211
213
|
|
|
212
214
|
return backward_pre_hook
|
msprobe/msprobe.py
CHANGED
|
@@ -14,16 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import argparse
|
|
17
|
-
import sys
|
|
18
17
|
import importlib.util
|
|
18
|
+
import sys
|
|
19
19
|
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.common.file_utils import root_privilege_warning
|
|
21
22
|
from msprobe.core.common.log import logger
|
|
22
|
-
from msprobe.core.compare.utils import _compare_parser
|
|
23
23
|
from msprobe.core.compare.compare_cli import compare_cli
|
|
24
24
|
from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
|
|
25
|
-
from msprobe.core.
|
|
26
|
-
|
|
25
|
+
from msprobe.core.compare.utils import _compare_parser
|
|
26
|
+
from msprobe.core.config_check.config_check_cli import _config_checking_parser, _run_config_checking_command
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def is_module_available(module_name):
|
|
@@ -64,6 +64,8 @@ def main():
|
|
|
64
64
|
if len(sys.argv) < 4:
|
|
65
65
|
parser.print_help()
|
|
66
66
|
sys.exit(0)
|
|
67
|
+
|
|
68
|
+
root_privilege_warning()
|
|
67
69
|
framework_args = parser.parse_args(sys.argv[1:3])
|
|
68
70
|
if framework_args.framework == Const.PT_FRAMEWORK:
|
|
69
71
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
@@ -24,7 +24,8 @@ from msprobe.pytorch.pt_config import RunUTConfig
|
|
|
24
24
|
|
|
25
25
|
RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
26
26
|
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
27
|
-
'black_list', 'error_data_path'])
|
|
27
|
+
'black_list', 'error_data_path', 'online_config'])
|
|
28
|
+
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
class Config:
|
|
@@ -45,7 +46,13 @@ class Config:
|
|
|
45
46
|
'white_list': list,
|
|
46
47
|
'black_list': list,
|
|
47
48
|
'error_data_path': str,
|
|
48
|
-
'precision': int
|
|
49
|
+
'precision': int,
|
|
50
|
+
'is_online': bool,
|
|
51
|
+
'nfs_path': str,
|
|
52
|
+
'host': str,
|
|
53
|
+
'port': int,
|
|
54
|
+
'rank_list': list,
|
|
55
|
+
'tls_path': str
|
|
49
56
|
}
|
|
50
57
|
if key not in validators:
|
|
51
58
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
@@ -61,6 +68,10 @@ class Config:
|
|
|
61
68
|
RunUTConfig.check_filter_list_config(key, value)
|
|
62
69
|
if key == 'error_data_path':
|
|
63
70
|
RunUTConfig.check_error_data_path_config(value)
|
|
71
|
+
if key == 'nfs_path':
|
|
72
|
+
RunUTConfig.check_nfs_path_config(value)
|
|
73
|
+
if key == 'tls_path':
|
|
74
|
+
RunUTConfig.check_tls_path_config(value)
|
|
64
75
|
return value
|
|
65
76
|
|
|
66
77
|
|
|
@@ -74,6 +85,12 @@ class CheckerConfig:
|
|
|
74
85
|
self.white_list = msCheckerConfig.white_list
|
|
75
86
|
self.black_list = msCheckerConfig.black_list
|
|
76
87
|
self.error_data_path = msCheckerConfig.error_data_path
|
|
88
|
+
self.is_online = msCheckerConfig.is_online
|
|
89
|
+
self.nfs_path = msCheckerConfig.nfs_path
|
|
90
|
+
self.host = msCheckerConfig.host
|
|
91
|
+
self.port = msCheckerConfig.port
|
|
92
|
+
self.rank_list = msCheckerConfig.rank_list
|
|
93
|
+
self.tls_path = msCheckerConfig.tls_path
|
|
77
94
|
|
|
78
95
|
if task_config:
|
|
79
96
|
self.load_config(task_config)
|
|
@@ -82,7 +99,22 @@ class CheckerConfig:
|
|
|
82
99
|
self.white_list = task_config.white_list
|
|
83
100
|
self.black_list = task_config.black_list
|
|
84
101
|
self.error_data_path = task_config.error_data_path
|
|
102
|
+
self.is_online = task_config.is_online
|
|
103
|
+
self.nfs_path = task_config.nfs_path
|
|
104
|
+
self.host = task_config.host
|
|
105
|
+
self.port = task_config.port
|
|
106
|
+
self.rank_list = task_config.rank_list
|
|
107
|
+
self.tls_path = task_config.tls_path
|
|
85
108
|
|
|
109
|
+
def get_online_config(self):
|
|
110
|
+
return OnlineConfig(
|
|
111
|
+
is_online=self.is_online,
|
|
112
|
+
nfs_path=self.nfs_path,
|
|
113
|
+
host=self.host,
|
|
114
|
+
port=self.port,
|
|
115
|
+
rank_list=self.rank_list,
|
|
116
|
+
tls_path=self.tls_path
|
|
117
|
+
)
|
|
86
118
|
|
|
87
119
|
def get_run_ut_config(self, **config_params):
|
|
88
120
|
return RunUtConfig(
|
|
@@ -95,5 +127,6 @@ class CheckerConfig:
|
|
|
95
127
|
real_data_path=config_params.get('real_data_path'),
|
|
96
128
|
white_list=self.white_list.copy() if self.white_list else [],
|
|
97
129
|
black_list=self.black_list.copy() if self.black_list else [],
|
|
98
|
-
error_data_path=config_params.get('error_data_path')
|
|
130
|
+
error_data_path=config_params.get('error_data_path'),
|
|
131
|
+
online_config=self.get_online_config()
|
|
99
132
|
)
|
|
@@ -117,6 +117,30 @@ def api_precision_compare(config):
|
|
|
117
117
|
change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
118
118
|
|
|
119
119
|
|
|
120
|
+
def online_api_precision_compare(online_config):
|
|
121
|
+
rank = online_config.rank
|
|
122
|
+
result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
|
|
123
|
+
"_rank*.csv", f"_rank{rank}.csv")
|
|
124
|
+
details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
|
|
125
|
+
"_rank*.csv", f"_rank{rank}.csv")
|
|
126
|
+
detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
|
|
127
|
+
result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
|
|
128
|
+
if not os.path.exists(result_csv_path):
|
|
129
|
+
write_csv(result_csv_title, result_csv_path)
|
|
130
|
+
if not os.path.exists(details_csv_path):
|
|
131
|
+
write_csv(detail_csv_title, details_csv_path)
|
|
132
|
+
config = CompareConfig("", "", result_csv_path, details_csv_path)
|
|
133
|
+
try:
|
|
134
|
+
npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
|
|
135
|
+
check_csv_columns(npu_data.columns, "npu_csv")
|
|
136
|
+
check_csv_columns(gpu_data.columns, "gpu_csv")
|
|
137
|
+
analyse_csv(npu_data, gpu_data, config)
|
|
138
|
+
except Exception as err:
|
|
139
|
+
logger.error(f"Online api precision compare Error: {str(err)}")
|
|
140
|
+
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
141
|
+
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
142
|
+
|
|
143
|
+
|
|
120
144
|
def analyse_csv(npu_data, gpu_data, config):
|
|
121
145
|
forward_status, backward_status = [], []
|
|
122
146
|
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
@@ -66,6 +66,13 @@ class Comparator:
|
|
|
66
66
|
self.save_path_list = [result_csv_path]
|
|
67
67
|
self.detail_save_path_list = [details_csv_path]
|
|
68
68
|
|
|
69
|
+
if config and config.online_config.is_online:
|
|
70
|
+
self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
|
|
71
|
+
self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
|
|
72
|
+
self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
73
|
+
self.detail_save_path_list = \
|
|
74
|
+
[self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
75
|
+
|
|
69
76
|
self.registry = self._register_compare_func()
|
|
70
77
|
|
|
71
78
|
if not is_continue_run_ut:
|
|
@@ -238,8 +245,9 @@ class Comparator:
|
|
|
238
245
|
self.write_detail_csv(args)
|
|
239
246
|
|
|
240
247
|
|
|
241
|
-
def compare_output(self, full_api_name, data_info):
|
|
248
|
+
def compare_output(self, full_api_name, data_info, is_online=False):
|
|
242
249
|
"""Get compare result and write to result and detail csv.
|
|
250
|
+
is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
|
|
243
251
|
"""
|
|
244
252
|
_, api_name = extract_basic_api_segments(full_api_name)
|
|
245
253
|
if not api_name:
|
|
@@ -272,7 +280,9 @@ class Comparator:
|
|
|
272
280
|
fwd_compare_alg_results,
|
|
273
281
|
bwd_compare_alg_results,
|
|
274
282
|
data_info.rank)
|
|
275
|
-
|
|
283
|
+
if is_online:
|
|
284
|
+
# get run_ut compare detail
|
|
285
|
+
return self._get_run_ut_detail(result_info)
|
|
276
286
|
self.record_results(result_info)
|
|
277
287
|
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
|
|
278
288
|
or bwd_success_status == CompareConst.SPACE
|