mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
- msprobe/README.md +8 -5
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +64 -13
- msprobe/core/common/framework_adapter.py +10 -1
- msprobe/core/common/utils.py +17 -0
- msprobe/core/compare/utils.py +26 -6
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +2 -16
- msprobe/core/service.py +5 -16
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +0 -13
- msprobe/docs/05.data_dump_PyTorch.md +1 -1
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
- msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/19.monitor.md +4 -4
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +0 -70
- msprobe/pytorch/debugger/debugger_config.py +0 -10
- msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
- msprobe/pytorch/hook_module/api_register.py +14 -3
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +10 -14
- msprobe/visualization/builder/graph_builder.py +2 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/db_utils.py +42 -18
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/visualization/graph_service.py +20 -10
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/pt_config.py
CHANGED
```diff
@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
-        self.online_run_ut = json_config.get("online_run_ut", False)
-        self.nfs_path = json_config.get("nfs_path", "")
-        self.host = json_config.get("host", "")
-        self.port = json_config.get("port", -1)
-        self.tls_path = json_config.get("tls_path", "./")
-        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-
-        self._check_online_run_ut()
+
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
-    def _check_online_run_ut(self):
-        if not isinstance(self.online_run_ut, bool):
-            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
-
-        if not isinstance(self.online_run_ut_recompute, bool):
-            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
-
-        if self.nfs_path:
-            check_file_or_directory_path(self.nfs_path, isdir=True)
-            return
-
-        if self.tls_path:
-            check_file_or_directory_path(self.tls_path, isdir=True)
-            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
-            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
-            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
-            crl_path = os.path.join(self.tls_path, "crl.pem")
-            if os.path.exists(crl_path):
-                check_file_or_directory_path(crl_path)
-
-        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
-            raise Exception(f"host: {self.host} is invalid.")
-
-        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
-            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
-
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
@@ -257,12 +224,7 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-
-        self.nfs_path = json_config.get("nfs_path", "")
-        self.host = json_config.get("host", "")
-        self.port = json_config.get("port", -1)
-        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
-        self.tls_path = json_config.get("tls_path", "./")
+
         self.check_run_ut_config()
 
     @classmethod
@@ -280,22 +242,11 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
            raise Exception("error_data_path: %s does not exist" % error_data_path)
 
-    @classmethod
-    def check_nfs_path_config(cls, nfs_path):
-        if nfs_path:
-            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
-
-    @classmethod
-    def check_tls_path_config(cls, tls_path):
-        if tls_path:
-            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
-        RunUTConfig.check_nfs_path_config(self.nfs_path)
-        RunUTConfig.check_tls_path_config(self.tls_path)
 
 
 class GradToolConfig(BaseConfig):
```
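Both hunks drop the online run_ut transport settings (`online_run_ut`, `nfs_path`, `host`, `port`, `tls_path`, `rank_list`, `online_run_ut_recompute`) from `TensorConfig` and `RunUTConfig`, along with their validators. A minimal sketch of a `run_ut` section that still passes `check_run_ut_config` after this change; the key names come from the diff, the values are illustrative:

```python
# Illustrative values only; per the diff the defaults are
# Const.DEFAULT_LIST and Const.DEFAULT_PATH.
run_ut_section = {
    "white_list": [],                      # API names to check exclusively
    "black_list": [],                      # API names to skip
    "error_data_path": "./ut_error_data",  # must exist (check_error_data_path_config)
}
# The removed keys ("online_run_ut", "nfs_path", "host", "port",
# "rank_list", "tls_path", "online_run_ut_recompute") are no longer
# read in 8.3.2.
```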
msprobe/pytorch/pytorch_service.py
CHANGED

```diff
@@ -15,19 +15,20 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
-from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
-from msprobe.pytorch.hook_module.api_register import
+from msprobe.pytorch.hook_module.api_register import (
+    get_api_register,
+    ApiTemplate,
+    redirect_wait,
+    reset_dist_collect_func
+)
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
-if torch_version_above_or_equal_2:
-    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
-
 
 class PytorchService(BaseService):
     @property
@@ -37,7 +38,7 @@ class PytorchService(BaseService):
     @staticmethod
     def _get_current_rank():
         return get_rank_if_initialized()
-
+
     def reset_status(self):
         self._reset_status()
 
@@ -45,12 +46,10 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.
-        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
-        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
@@ -65,11 +64,8 @@ class PytorchService(BaseService):
         self.module_processor.register_module_hook(self.model, self.build_hook)
         self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
-    def _run_ut_dispatch(self, status):
-        if torch_version_above_or_equal_2:
-            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
-
     def _reset_status(self):
         super()._reset_status()
         ModuleProcesser.reset_module_stats()
         HOOKModule.reset_module_stats()
+        reset_dist_collect_func()
```
msprobe/visualization/builder/graph_builder.py
CHANGED

```diff
@@ -298,8 +298,8 @@ class GraphBuilder:
         no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
         if not no_recompute_map:
             return
-        #
-        no_recompute_ids_b =
+        # Copy the map of non-recompute nodes for use in backward mode
+        no_recompute_ids_b = {node_id: list(node_list) for node_id, node_list in no_recompute_map.items()}
 
         del_indexes = []
         for node_id, id_prefix in recompute_map.items():
```
msprobe/visualization/builder/graph_merger.py
CHANGED

```diff
@@ -146,6 +146,7 @@ class BaseGraphMerger:
                                            GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
                                            id_accumulation=True)
         all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
+        all_collection_node.upnode = main_graph_result.graph.root
         new_main_root_sub_nodes.append(all_collection_node)
         # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
         origin_main_node_id = main_node.id
@@ -377,6 +378,12 @@ class PPMerger(BaseGraphMerger):
             logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), '
                         'generate pp groups using parallel param "rank_size", "tp" and "pp".')
             _, pp_groups = self.get_default_groups()
+        elif len(pp_groups[0]) != self.parallel_param.pp:
+            logger.warning(f'Based on Distributed Api (batch_isend_irecv, send, or isend), '
+                           f'the resulting pp groups={pp_groups}, '
+                           f'its length is not equal to the parallel param "pp"({self.parallel_param.pp}) you defined, '
+                           f'generate pp groups using parallel param "rank_size", "tp" and "pp".')
+            _, pp_groups = self.get_default_groups()
         logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.')
         return pp_groups
 
@@ -657,6 +664,12 @@ class TPMerger(BaseGraphMerger):
             logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), '
                         'generate tp groups using parallel param "rank_size", "tp" and "pp".')
             tp_groups, _ = self.get_default_groups()
+        elif len(tp_groups[0]) != self.parallel_param.tp:
+            logger.warning(f'Based on Distributed Api (reduce_scatter or all_reduce), '
+                           f'the resulting tp groups={tp_groups}, '
+                           f'its length is not equal to the parallel param "tp"({self.parallel_param.tp}) you defined, '
+                           f'generate tp groups using parallel param "rank_size", "tp" and "pp".')
+            tp_groups, _ = self.get_default_groups()
         logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.')
         return tp_groups
 
```
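Both hunks add the same guard: when the pp/tp groups inferred from communication ops disagree in size with the declared parallel params, the merger falls back to `get_default_groups()`. A hypothetical sketch of what such a default construction could look like, assuming Megatron-style rank ordering; this is not msprobe's actual implementation:

```python
# Hypothetical fallback: rebuild groups from "rank_size", "tp" and "pp",
# assuming tensor-parallel ranks are innermost (Megatron-style ordering).
def default_groups(rank_size, tp, pp):
    # tp groups: consecutive chunks of size tp
    tp_groups = [list(range(start, start + tp)) for start in range(0, rank_size, tp)]
    # pp groups: ranks strided across pipeline stages
    stride = rank_size // pp
    pp_groups = [list(range(start, rank_size, stride)) for start in range(stride)]
    return tp_groups, pp_groups

tp_groups, pp_groups = default_groups(rank_size=8, tp=2, pp=2)
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> len(pp_groups[0]) == pp
```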
msprobe/visualization/db_utils.py
CHANGED

```diff
@@ -17,6 +17,7 @@ import os
 import sqlite3
 import json
 import re
+import time
 from msprobe.core.common.log import logger
 from msprobe.core.common.file_utils import change_mode, check_path_before_create, FileChecker
 from msprobe.core.common.const import FileCheckConst
@@ -133,33 +134,56 @@ def create_insert_sql_from_dict(table_name, columns_dict, ignore_insert=False):
 
 
 def to_db(db_path, create_table_sql, insert_sql, data, db_insert_size=1000):
+    max_retries = 10
+    initial_delay = 0.1
     if not os.path.exists(db_path):
         check_path_before_create(db_path)
     else:
         FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE,
                     FileCheckConst.DB_SUFFIX).common_check()
-    try:
-        conn = sqlite3.connect(db_path)
-    except sqlite3.Error as e:
-        logger.error(f"Unable to create database connection: {e}")
-        raise RuntimeError("Unable to create database connection") from e
 
-
-
-
-
-
-
-
+    retry_count = 0
+    current_delay = initial_delay
+
+    while retry_count <= max_retries:
+        conn = None
+        try:
+            conn = sqlite3.connect(db_path, timeout=30)
+            cursor = conn.cursor()
+            # Enable WAL mode to improve multi-process read/write concurrency
+            cursor.execute("PRAGMA journal_mode=WAL")
+            cursor.execute("PRAGMA synchronous=NORMAL")
+            cursor.execute(create_table_sql)
             for i in range(0, len(data), db_insert_size):
                 batch = data[i:i + db_insert_size]
                 cursor.executemany(insert_sql, batch)
-
-
-
-
-
-
+            conn.commit()
+            return
+        except sqlite3.OperationalError as e:
+            if "database is locked" in str(e).lower():
+                retry_count += 1
+                if retry_count > max_retries:
+                    logger.error(f"Database lock conflict retry attempts exhausted ({max_retries}): {e}")
+                    raise RuntimeError(f"DB lock retry exhausted: {e}") from e
+
+                logger.warning(
+                    f"DB lock conflict (retry {retry_count}/{max_retries}), wait {current_delay:.2f}s : {e}"
+                )
+                time.sleep(current_delay)
+                current_delay *= 2
+                continue
+
+            logger.error(f"An sqlite3 error occurred: {e}")
+            raise e
+        except sqlite3.Error as e:
+            logger.error(f"An sqlite3 error occurred: {e}")
+            raise e
+        except Exception as e:
+            logger.error(f"An unknown error occurred: {e}")
+            raise e
+        finally:
+            if conn:
+                conn.close()
 
 
 def add_table_index(db_path):
```
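The rewritten `to_db` combines three standard SQLite techniques for contended writers: a connection `timeout`, WAL journaling, and exponential backoff on `database is locked`. A self-contained sketch of the same pattern, with illustrative names rather than the msprobe API:

```python
import sqlite3
import time

def write_with_retry(db_path, insert_sql, rows, max_retries=10, initial_delay=0.1):
    """Retry batched SQLite writes that fail with 'database is locked'."""
    delay = initial_delay
    for attempt in range(max_retries + 1):
        conn = None
        try:
            conn = sqlite3.connect(db_path, timeout=30)
            cursor = conn.cursor()
            # WAL lets readers proceed while a single writer holds the lock.
            cursor.execute("PRAGMA journal_mode=WAL")
            cursor.execute("PRAGMA synchronous=NORMAL")
            cursor.executemany(insert_sql, rows)
            conn.commit()
            return
        except sqlite3.OperationalError as e:
            # Only lock contention is retried; other errors propagate immediately.
            if "database is locked" not in str(e).lower() or attempt == max_retries:
                raise
            time.sleep(delay)
            delay *= 2  # exponential backoff, as in the new to_db
        finally:
            if conn:
                conn.close()
```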
msprobe/visualization/graph/graph.py
CHANGED

```diff
@@ -126,21 +126,25 @@ class Graph:
 
     def get_sorted_nodes(self):
         """
-        Traverse the graph depth-first and return the nodes in sorted order
+        Traverse the graph depth-first and return the sorted node list, implemented with an explicit stack to avoid exceeding the recursion depth
         """
         visited = set()
         order = []
+        stack = [(self.root, False)]
 
-
-
+        while stack:
+            node, processed = stack.pop()
             if node.id in visited:
-
-
-
-
-
+                continue
+            if processed:
+                visited.add(node.id)
+                order.append(node)
+            else:
+                stack.append((node, True))
+                for sub_node in reversed(node.subnodes):
+                    if sub_node.id not in visited:
+                        stack.append((sub_node, False))
 
-        visit(self.root)
         return order
 
     def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
```
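`get_sorted_nodes` replaces the recursive `visit` helper with an explicit stack of `(node, processed)` pairs, the standard iterative post-order traversal: each node is pushed once to expand its children and a second time, flagged processed, to be emitted after them. A self-contained sketch of the technique; the `Node` class here is illustrative, not msprobe's:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    id: str
    subnodes: list = field(default_factory=list)

def sorted_nodes(root):
    """Post-order DFS without recursion."""
    visited, order = set(), []
    stack = [(root, False)]
    while stack:
        node, processed = stack.pop()
        if node.id in visited:
            continue
        if processed:
            # Second visit: all children already emitted, emit the node itself.
            visited.add(node.id)
            order.append(node)
        else:
            # First visit: re-push flagged, then push children.
            stack.append((node, True))
            for sub in reversed(node.subnodes):  # reversed keeps left-to-right order
                if sub.id not in visited:
                    stack.append((sub, False))
    return order

root = Node("root", [Node("a", [Node("a.0")]), Node("b")])
print([n.id for n in sorted_nodes(root)])  # ['a.0', 'a', 'b', 'root']
```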
msprobe/visualization/graph_service.py
CHANGED

```diff
@@ -242,11 +242,15 @@ def _compare_graph_ranks(input_param, args, step=None):
 def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call):
     dump_rank_n = input_param.get('npu_path')
     dump_rank_b = input_param.get('bench_path')
-    npu_ranks =
-    bench_ranks =
+    npu_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_n, Const.RANK))
+    bench_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_b, Const.RANK))
     if npu_ranks != bench_ranks:
-
-
+        intersection_ranks = sort_rank_number_strings(list(set(npu_ranks) & set(bench_ranks)))
+        if not intersection_ranks:
+            logger.error('The ranks in the two runs are completely different. Unable to match the ranks.')
+            raise CompareException(CompareException.INVALID_PATH_ERROR)
+        npu_ranks = intersection_ranks
+        bench_ranks = intersection_ranks
     compare_graph_results = []
     if is_real_data_compare(input_param, npu_ranks, bench_ranks):
         mp_task_dict = {}
@@ -282,12 +286,16 @@ def _compare_graph_steps(input_param, args):
     dump_step_n = input_param.get('npu_path')
     dump_step_b = input_param.get('bench_path')
 
-    npu_steps =
-    bench_steps =
+    npu_steps = check_and_return_dir_contents(dump_step_n, Const.STEP)
+    bench_steps = check_and_return_dir_contents(dump_step_b, Const.STEP)
 
     if npu_steps != bench_steps:
-
-
+        intersection_steps = sort_rank_number_strings(list(set(npu_steps) & set(bench_steps)))
+
+        if not intersection_steps:
+            logger.error('The steps in the two runs are completely different. Unable to match the steps.')
+            raise CompareException(CompareException.INVALID_PATH_ERROR)
+        npu_steps = intersection_steps
 
     args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps])
 
@@ -355,8 +363,10 @@ def _build_graph_steps(dump_steps_path, args):
         _build_graph_ranks(dump_ranks_path, args, step)
 
 
-def _compare_and_export_graph(graph_task_info, input_param, args):
+def _compare_and_export_graph(graph_task_info, input_param, args, step=None):
     result = _run_graph_compare(graph_task_info, input_param, args)
+    if step is not None:
+        result.step = get_step_or_rank_int(step)
     return _export_compare_graph_result(args, result)
 
 
@@ -413,7 +423,7 @@ def _compare_graph_ranks_parallel(input_param, args, step=None):
             _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b),
             f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time)
         export_res_task_list.append(pool.apply_async(_compare_and_export_graph,
-                                                     args=(graph_task_info, input_param, serializable_args),
+                                                     args=(graph_task_info, input_param, serializable_args, step),
                                                      error_callback=err_call))
     export_res_list = [res.get() for res in export_res_task_list]
     if any(export_res_list):
```