mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/utils.py +26 -6
  9. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  10. msprobe/core/hook_manager.py +2 -16
  11. msprobe/core/service.py +5 -16
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +0 -13
  14. msprobe/docs/05.data_dump_PyTorch.md +1 -1
  15. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
  16. msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
  17. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  18. msprobe/docs/19.monitor.md +4 -4
  19. msprobe/docs/21.visualization_PyTorch.md +1 -1
  20. msprobe/docs/25.tool_function_introduction.md +0 -1
  21. msprobe/docs/32.ckpt_compare.md +5 -5
  22. msprobe/mindspore/monitor/module_hook.py +17 -20
  23. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  24. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  25. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  26. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  27. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  28. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  29. msprobe/pytorch/common/utils.py +0 -70
  30. msprobe/pytorch/debugger/debugger_config.py +0 -10
  31. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  32. msprobe/pytorch/hook_module/api_register.py +14 -3
  33. msprobe/pytorch/monitor/module_hook.py +16 -34
  34. msprobe/pytorch/pt_config.py +2 -51
  35. msprobe/pytorch/pytorch_service.py +10 -14
  36. msprobe/visualization/builder/graph_builder.py +2 -2
  37. msprobe/visualization/builder/graph_merger.py +13 -0
  38. msprobe/visualization/db_utils.py +42 -18
  39. msprobe/visualization/graph/graph.py +13 -9
  40. msprobe/visualization/graph_service.py +20 -10
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  43. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  44. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  45. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  46. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  47. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  48. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  49. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  50. msprobe/pytorch/attl_manager.py +0 -65
  51. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
  52. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
  53. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
  54. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/pt_config.py
@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
-        self.online_run_ut = json_config.get("online_run_ut", False)
-        self.nfs_path = json_config.get("nfs_path", "")
-        self.host = json_config.get("host", "")
-        self.port = json_config.get("port", -1)
-        self.tls_path = json_config.get("tls_path", "./")
-        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-        if self.online_run_ut:
-            self._check_online_run_ut()
+
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
-    def _check_online_run_ut(self):
-        if not isinstance(self.online_run_ut, bool):
-            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
-
-        if not isinstance(self.online_run_ut_recompute, bool):
-            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
-
-        if self.nfs_path:
-            check_file_or_directory_path(self.nfs_path, isdir=True)
-            return
-
-        if self.tls_path:
-            check_file_or_directory_path(self.tls_path, isdir=True)
-            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
-            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
-            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
-            crl_path = os.path.join(self.tls_path, "crl.pem")
-            if os.path.exists(crl_path):
-                check_file_or_directory_path(crl_path)
-
-        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
-            raise Exception(f"host: {self.host} is invalid.")
-
-        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
-            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
-
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
@@ -257,12 +224,7 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-        self.is_online = json_config.get("is_online", False)
-        self.nfs_path = json_config.get("nfs_path", "")
-        self.host = json_config.get("host", "")
-        self.port = json_config.get("port", -1)
-        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
-        self.tls_path = json_config.get("tls_path", "./")
+
         self.check_run_ut_config()
 
     @classmethod
@@ -280,22 +242,11 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
             raise Exception("error_data_path: %s does not exist" % error_data_path)
 
-    @classmethod
-    def check_nfs_path_config(cls, nfs_path):
-        if nfs_path:
-            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
-
-    @classmethod
-    def check_tls_path_config(cls, tls_path):
-        if tls_path:
-            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
-        RunUTConfig.check_nfs_path_config(self.nfs_path)
-        RunUTConfig.check_tls_path_config(self.tls_path)
 
 
 class GradToolConfig(BaseConfig):
msprobe/pytorch/pytorch_service.py
@@ -15,19 +15,20 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
-from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
+from msprobe.pytorch.common.utils import get_rank_if_initialized
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
-from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
+from msprobe.pytorch.hook_module.api_register import (
+    get_api_register,
+    ApiTemplate,
+    redirect_wait,
+    reset_dist_collect_func
+)
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
-if torch_version_above_or_equal_2:
-    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
-
 
 class PytorchService(BaseService):
     @property
@@ -37,7 +38,7 @@ class PytorchService(BaseService):
     @staticmethod
     def _get_current_rank():
         return get_rank_if_initialized()
-
+
     def reset_status(self):
         self._reset_status()
 
@@ -45,12 +46,10 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.attl_manager = ATTLManager(self.config)
-        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
-        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
@@ -65,11 +64,8 @@ class PytorchService(BaseService):
         self.module_processor.register_module_hook(self.model, self.build_hook)
         self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
-    def _run_ut_dispatch(self, status):
-        if torch_version_above_or_equal_2:
-            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
-
     def _reset_status(self):
         super()._reset_status()
         ModuleProcesser.reset_module_stats()
         HOOKModule.reset_module_stats()
+        reset_dist_collect_func()
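Note: reset_dist_collect_func joins the other reset calls here, presumably because distributed-op collection state lives at module level in api_register and must be cleared explicitly between debugger sessions. A hedged sketch of that reset idiom (the names below are illustrative, not msprobe internals):

```python
# Illustrative module-level collection state plus an explicit reset hook,
# mirroring the reset_module_stats / reset_dist_collect_func idiom.
_collected_ops = []

def collect(op_name):
    _collected_ops.append(op_name)

def reset_collect():
    # Called from the service's _reset_status so a new session starts clean.
    _collected_ops.clear()

collect("dist.all_reduce")
reset_collect()
assert not _collected_ops
```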
msprobe/visualization/builder/graph_builder.py
@@ -298,8 +298,8 @@ class GraphBuilder:
         no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
         if not no_recompute_map:
             return
-        # Deep-copy the non-recompute node map for the backward pass
-        no_recompute_ids_b = copy.deepcopy(no_recompute_map)
+        # Copy the non-recompute node map for the backward pass
+        no_recompute_ids_b = {node_id: list(node_list) for node_id, node_list in no_recompute_map.items()}
 
         del_indexes = []
         for node_id, id_prefix in recompute_map.items():
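Note: swapping copy.deepcopy for a dict comprehension is safe here because each value is a flat list; a one-level copy already decouples the two maps without walking nested objects. A minimal sketch of the difference (the map contents are illustrative, not taken from a real graph):

```python
# A one-level copy is enough when values are flat lists of items
# that are never mutated in place.
no_recompute_map = {"Module.layer1.0": ["node_a", "node_b"]}

copied = {k: list(v) for k, v in no_recompute_map.items()}
copied["Module.layer1.0"].append("node_c")

print(no_recompute_map["Module.layer1.0"])  # ['node_a', 'node_b'] -- original untouched
```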
msprobe/visualization/builder/graph_merger.py
@@ -146,6 +146,7 @@ class BaseGraphMerger:
                              GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
                              id_accumulation=True)
         all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
+        all_collection_node.upnode = main_graph_result.graph.root
         new_main_root_sub_nodes.append(all_collection_node)
         # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
         origin_main_node_id = main_node.id
@@ -377,6 +378,12 @@ class PPMerger(BaseGraphMerger):
             logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), '
                         'generate pp groups using parallel param "rank_size", "tp" and "pp".')
             _, pp_groups = self.get_default_groups()
+        elif len(pp_groups[0]) != self.parallel_param.pp:
+            logger.warning(f'Based on Distributed Api (batch_isend_irecv, send, or isend), '
+                           f'the resulting pp groups={pp_groups}, '
+                           f'its length is not equal to the parallel param "pp"({self.parallel_param.pp}) you defined, '
+                           f'generate pp groups using parallel param "rank_size", "tp" and "pp".')
+            _, pp_groups = self.get_default_groups()
         logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.')
         return pp_groups
 
@@ -657,6 +664,12 @@ class TPMerger(BaseGraphMerger):
             logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), '
                         'generate tp groups using parallel param "rank_size", "tp" and "pp".')
             tp_groups, _ = self.get_default_groups()
+        elif len(tp_groups[0]) != self.parallel_param.tp:
+            logger.warning(f'Based on Distributed Api (reduce_scatter or all_reduce), '
+                           f'the resulting tp groups={tp_groups}, '
+                           f'its length is not equal to the parallel param "tp"({self.parallel_param.tp}) you defined, '
+                           f'generate tp groups using parallel param "rank_size", "tp" and "pp".')
+            tp_groups, _ = self.get_default_groups()
         logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.')
         return tp_groups
 
msprobe/visualization/db_utils.py
@@ -17,6 +17,7 @@ import os
 import sqlite3
 import json
 import re
+import time
 from msprobe.core.common.log import logger
 from msprobe.core.common.file_utils import change_mode, check_path_before_create, FileChecker
 from msprobe.core.common.const import FileCheckConst
@@ -133,33 +134,56 @@ def create_insert_sql_from_dict(table_name, columns_dict, ignore_insert=False):
 
 
 def to_db(db_path, create_table_sql, insert_sql, data, db_insert_size=1000):
+    max_retries = 10
+    initial_delay = 0.1
     if not os.path.exists(db_path):
         check_path_before_create(db_path)
     else:
         FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE,
                     FileCheckConst.DB_SUFFIX).common_check()
-    try:
-        conn = sqlite3.connect(db_path)
-    except sqlite3.Error as e:
-        logger.error(f"Unable to create database connection: {e}")
-        raise RuntimeError("Unable to create database connection") from e
 
-    try:
-        cursor = conn.cursor()
-        cursor.execute(create_table_sql)
-        if len(data) == 1:
-            cursor.execute(insert_sql, data[0])
-            conn.commit()
-        else:
+    retry_count = 0
+    current_delay = initial_delay
+
+    while retry_count <= max_retries:
+        conn = None
+        try:
+            conn = sqlite3.connect(db_path, timeout=30)
+            cursor = conn.cursor()
+            # Enable WAL mode to improve multi-process read/write concurrency
+            cursor.execute("PRAGMA journal_mode=WAL")
+            cursor.execute("PRAGMA synchronous=NORMAL")
+            cursor.execute(create_table_sql)
             for i in range(0, len(data), db_insert_size):
                 batch = data[i:i + db_insert_size]
                 cursor.executemany(insert_sql, batch)
-        conn.commit()
-    except sqlite3.Error as e:
-        logger.error(f"An sqlite3 error occurred: {e}")
-        raise RuntimeError("An sqlite3 error occurred") from e
-    finally:
-        conn.close()
+            conn.commit()
+            return
+        except sqlite3.OperationalError as e:
+            if "database is locked" in str(e).lower():
+                retry_count += 1
+                if retry_count > max_retries:
+                    logger.error(f"Database lock conflict retry attempts exhausted ({max_retries}): {e}")
+                    raise RuntimeError(f"DB lock retry exhausted: {e}") from e
+
+                logger.warning(
+                    f"DB lock conflict (retry {retry_count}/{max_retries}), wait {current_delay:.2f}s : {e}"
+                )
+                time.sleep(current_delay)
+                current_delay *= 2
+                continue
+
+            logger.error(f"An sqlite3 error occurred: {e}")
+            raise e
+        except sqlite3.Error as e:
+            logger.error(f"An sqlite3 error occurred: {e}")
+            raise e
+        except Exception as e:
+            logger.error(f"An unknown error occurred: {e}")
+            raise e
+        finally:
+            if conn:
+                conn.close()
 
 
 def add_table_index(db_path):
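Note: the rewritten to_db retries only the one transient failure mode ("database is locked") with exponential backoff, and WAL mode lets concurrent readers proceed while a single writer commits. The same pattern in standalone form (the table name and data below are illustrative):

```python
import sqlite3
import time

def insert_with_retry(db_path, rows, max_retries=10, delay=0.1):
    # Retry only on lock contention; re-raise any other sqlite3 error.
    for attempt in range(max_retries + 1):
        conn = None
        try:
            conn = sqlite3.connect(db_path, timeout=30)
            conn.execute("PRAGMA journal_mode=WAL")    # readers no longer block the writer
            conn.execute("PRAGMA synchronous=NORMAL")  # fewer fsyncs; safe with WAL
            conn.execute("CREATE TABLE IF NOT EXISTS t (k TEXT, v REAL)")
            conn.executemany("INSERT INTO t VALUES (?, ?)", rows)
            conn.commit()
            return
        except sqlite3.OperationalError as e:
            if "database is locked" not in str(e).lower() or attempt == max_retries:
                raise
            time.sleep(delay * (2 ** attempt))         # exponential backoff
        finally:
            if conn:
                conn.close()

insert_with_retry("demo.db", [("loss", 0.12), ("grad_norm", 1.5)])
```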
msprobe/visualization/graph/graph.py
@@ -126,21 +126,25 @@ class Graph:
 
     def get_sorted_nodes(self):
         """
-        Obtain a sorted node list via a depth-first traversal of the graph
+        Obtain a sorted node list via a depth-first traversal of the graph, implemented with an explicit stack to avoid exceeding the recursion depth
        """
         visited = set()
         order = []
+        stack = [(self.root, False)]
 
-        @recursion_depth_decorator('msprobe.visualization.graph.graph.Graph.get_nodes_order.visit', max_depth=500)
-        def visit(node):
+        while stack:
+            node, processed = stack.pop()
             if node.id in visited:
-                return
-            visited.add(node.id)
-            for sub_node in node.subnodes:
-                visit(sub_node)
-            order.append(node)
+                continue
+            if processed:
+                visited.add(node.id)
+                order.append(node)
+            else:
+                stack.append((node, True))
+                for sub_node in reversed(node.subnodes):
+                    if sub_node.id not in visited:
+                        stack.append((sub_node, False))
 
-        visit(self.root)
         return order
 
     def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
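Note: the (node, processed) two-phase stack is a standard way to get post-order without recursion: a node is pushed once unprocessed, re-pushed as processed, and emitted only on the second visit, after its children. A minimal standalone version of the pattern on a toy tree:

```python
# Two-phase explicit-stack post-order traversal, same shape as get_sorted_nodes.
class Node:
    def __init__(self, id, subnodes=()):
        self.id, self.subnodes = id, list(subnodes)

def post_order(root):
    visited, order, stack = set(), [], [(root, False)]
    while stack:
        node, processed = stack.pop()
        if node.id in visited:
            continue
        if processed:                            # second visit: children already emitted
            visited.add(node.id)
            order.append(node.id)
        else:                                    # first visit: re-push, then children
            stack.append((node, True))
            for sub in reversed(node.subnodes):  # reversed keeps left-to-right order
                stack.append((sub, False))
    return order

tree = Node("root", [Node("a", [Node("a1")]), Node("b")])
print(post_order(tree))  # ['a1', 'a', 'b', 'root']
```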
msprobe/visualization/graph_service.py
@@ -242,11 +242,15 @@ def _compare_graph_ranks(input_param, args, step=None):
 def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call):
     dump_rank_n = input_param.get('npu_path')
     dump_rank_b = input_param.get('bench_path')
-    npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK))
-    bench_ranks = sorted(check_and_return_dir_contents(dump_rank_b, Const.RANK))
+    npu_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_n, Const.RANK))
+    bench_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_b, Const.RANK))
     if npu_ranks != bench_ranks:
-        logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
-        raise CompareException(CompareException.INVALID_PATH_ERROR)
+        intersection_ranks = sort_rank_number_strings(list(set(npu_ranks) & set(bench_ranks)))
+        if not intersection_ranks:
+            logger.error('The ranks in the two runs are completely different. Unable to match the ranks.')
+            raise CompareException(CompareException.INVALID_PATH_ERROR)
+        npu_ranks = intersection_ranks
+        bench_ranks = intersection_ranks
     compare_graph_results = []
     if is_real_data_compare(input_param, npu_ranks, bench_ranks):
         mp_task_dict = {}
@@ -282,12 +286,16 @@ def _compare_graph_steps(input_param, args):
     dump_step_n = input_param.get('npu_path')
     dump_step_b = input_param.get('bench_path')
 
-    npu_steps = sorted(check_and_return_dir_contents(dump_step_n, Const.STEP))
-    bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))
+    npu_steps = check_and_return_dir_contents(dump_step_n, Const.STEP)
+    bench_steps = check_and_return_dir_contents(dump_step_b, Const.STEP)
 
     if npu_steps != bench_steps:
-        logger.error('The number of steps in the two runs is different. Unable to match the steps.')
-        raise CompareException(CompareException.INVALID_PATH_ERROR)
+        intersection_steps = sort_rank_number_strings(list(set(npu_steps) & set(bench_steps)))
+
+        if not intersection_steps:
+            logger.error('The steps in the two runs are completely different. Unable to match the steps.')
+            raise CompareException(CompareException.INVALID_PATH_ERROR)
+        npu_steps = intersection_steps
 
     args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps])
 
@@ -355,8 +363,10 @@ def _build_graph_steps(dump_steps_path, args):
         _build_graph_ranks(dump_ranks_path, args, step)
 
 
-def _compare_and_export_graph(graph_task_info, input_param, args):
+def _compare_and_export_graph(graph_task_info, input_param, args, step=None):
     result = _run_graph_compare(graph_task_info, input_param, args)
+    if step is not None:
+        result.step = get_step_or_rank_int(step)
     return _export_compare_graph_result(args, result)
 
 
@@ -413,7 +423,7 @@ def _compare_graph_ranks_parallel(input_param, args, step=None):
                            _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b),
                            f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time)
         export_res_task_list.append(pool.apply_async(_compare_and_export_graph,
-                                                     args=(graph_task_info, input_param, serializable_args),
+                                                     args=(graph_task_info, input_param, serializable_args, step),
                                                      error_callback=err_call))
     export_res_list = [res.get() for res in export_res_task_list]
     if any(export_res_list):
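Note: threading step through _compare_and_export_graph lets each pooled task tag its own result, since workers launched via apply_async receive no ambient context from the caller. The pattern in miniature (function and field names are illustrative):

```python
from multiprocessing import Pool

def work(payload, step=None):
    result = {"payload": payload}
    if step is not None:  # tag the result inside the worker itself
        result["step"] = step
    return result

if __name__ == "__main__":
    with Pool(2) as pool:
        tasks = [pool.apply_async(work, args=(f"graph{i}", 3)) for i in range(2)]
        print([t.get() for t in tasks])
        # [{'payload': 'graph0', 'step': 3}, {'payload': 'graph1', 'step': 3}]
```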