mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
  3. msprobe/README.md +7 -5
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +105 -27
  7. msprobe/core/common/framework_adapter.py +7 -6
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/utils.py +14 -3
  10. msprobe/core/compare/find_first/analyzer.py +8 -7
  11. msprobe/core/compare/find_first/graph.py +11 -3
  12. msprobe/core/compare/find_first/utils.py +2 -1
  13. msprobe/core/compare/highlight.py +13 -6
  14. msprobe/core/compare/multiprocessing_compute.py +17 -10
  15. msprobe/core/compare/utils.py +14 -5
  16. msprobe/core/data_dump/data_collector.py +18 -21
  17. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  18. msprobe/core/data_dump/json_writer.py +18 -8
  19. msprobe/core/data_dump/scope.py +4 -6
  20. msprobe/core/hook_manager.py +37 -3
  21. msprobe/core/service.py +18 -5
  22. msprobe/core/single_save/single_comparator.py +16 -3
  23. msprobe/docs/01.installation.md +7 -5
  24. msprobe/docs/02.config_introduction.md +14 -1
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  29. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  30. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  31. msprobe/docs/19.monitor.md +2 -0
  32. msprobe/docs/21.visualization_PyTorch.md +15 -80
  33. msprobe/docs/22.visualization_MindSpore.md +20 -104
  34. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  35. msprobe/docs/25.tool_function_introduction.md +1 -0
  36. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  37. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  38. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  40. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  41. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  42. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  43. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  44. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  45. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  46. msprobe/mindspore/cell_processor.py +33 -5
  47. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  48. msprobe/mindspore/compare/utils.py +1 -2
  49. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  50. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  51. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  52. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  53. msprobe/msprobe.py +6 -4
  54. msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
  55. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
  58. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  59. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
  60. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  61. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
  62. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
  63. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
  64. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  65. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
  66. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  67. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
  68. msprobe/pytorch/attl_manager.py +65 -0
  69. msprobe/pytorch/common/utils.py +22 -2
  70. msprobe/pytorch/compare/utils.py +3 -3
  71. msprobe/pytorch/debugger/debugger_config.py +10 -0
  72. msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
  73. msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
  74. msprobe/pytorch/hook_module/api_register.py +6 -1
  75. msprobe/pytorch/monitor/module_hook.py +28 -9
  76. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  77. msprobe/pytorch/pt_config.py +57 -2
  78. msprobe/pytorch/pytorch_service.py +11 -2
  79. msprobe/visualization/builder/graph_builder.py +170 -64
  80. msprobe/visualization/builder/graph_merger.py +0 -1
  81. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  82. msprobe/visualization/db_utils.py +25 -2
  83. msprobe/visualization/graph/base_node.py +0 -24
  84. msprobe/visualization/graph/graph.py +5 -14
  85. msprobe/visualization/graph_service.py +29 -53
  86. msprobe/visualization/utils.py +11 -1
  87. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
  88. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
  89. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
  90. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
@@ -26,8 +26,7 @@ def read_npy_data(dir_path, file_name):
         return None
 
     data_path = os.path.join(dir_path, file_name)
-    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                               FileCheckConst.NUMPY_SUFFIX, False)
+    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX)
     data_path = path_checker.common_check()
     data_value = load_npy(data_path)
     return data_value
@@ -182,7 +182,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         with ThreadSafe():
             instance.service.step()
         if is_graph_mode_cell_dump_allowed(instance.config):
-            GraphModeCellDump.step()
+            GraphModeCellDump.step(instance.config.dump_path, instance.config.step, instance.config.task)
         if enable_dynamic_kbyk_dump and instance.config.level_ori == Const.LEVEL_L2:
             _dump_step(1)
         if cls._is_kernel_dump() and _msprobe_c:
@@ -46,9 +46,11 @@ KEY_FORWARD = CoreConst.FORWARD
 KEY_BACKWARD = CoreConst.BACKWARD
 KEY_INPUT = CoreConst.INPUT
 KEY_OUTPUT = CoreConst.OUTPUT
-KEY_DUMP_TENSOR_DATA = "dump_tensor_data_"
+KEY_DUMP_TENSOR_DATA = "dump_tensor_data/"
 KEY_STATISTIC_CSV = "statistic.csv"
 KEY_TD_FLAG = "td_flag"
+# Timeout (in seconds) for detecting that dumped files have landed on disk
+TIMEOUT = 600
 td = ops.TensorDump()
 if (ms.__version__ >= "2.5.0"):
     td_in = ops.TensorDump("in")
@@ -574,28 +576,33 @@ def generate_stack_info(path):
     logger.info(f"Stack data saved to {json_path}")
 
 
-def is_download_finished(directory, interval=3):
+def is_download_finished(directory, save_flag):
     """
     Check whether data in the given directory has finished downloading.
     :param directory: path of the directory to check
-    :param interval: check interval in seconds, default 3
+    :param save_flag: name prefix of the flag file written once the data has been flushed to disk
     :return: True if the data has finished downloading, otherwise False
     """
+    # Sleep briefly to avoid frequent disk I/O reads
+    time.sleep(0.5)
+    logger.info("Waiting for download...")
     # Check that the directory exists
     if not os.path.exists(directory):
         logger.warning(f"The specified directory {directory} does not exist.")
         return False
-    initial_modification_time = os.path.getmtime(directory)
-    time.sleep(interval)
-    current_modification_time = os.path.getmtime(directory)
-    # Compare the initial and current modification times
-    if current_modification_time > initial_modification_time:
-        return False
-    else:
-        return True
+
+    # Iterate over all entries in the directory
+    for entry_path in os.listdir(directory):
+        if entry_path.startswith(save_flag):
+            return True
+
+    return False
+
 
+def process_step(dump_path, flag_path, step, step_list):
+    if step not in step_list:
+        return
 
-def process(dump_path):
     if not os.path.exists(dump_path):
         logger.warning('No grap cell data is dumped.')
     create_directory(dump_path)
@@ -606,32 +613,38 @@ def process(dump_path):
     if rank_id is not None:
         rank_dir = CoreConst.RANK + str(rank_id)
 
-    step_dir_list = os.listdir(dump_path)
-    for step_dir in step_dir_list:
-        step_path = os.path.join(dump_path, step_dir)
-        rank_path = os.path.join(step_path, rank_dir)
-        npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
-        while True:
-            is_finished = is_download_finished(npy_path)
-            if not is_finished:
-                logger.info("There is data being downloaded in the specified directory, continue checking...")
-            else:
-                logger.info("There is no data being downloaded in the specified directory, Stop checking.")
-                break
-        logger.info("==========Start processing data that has already been stored on the disk!==========")
-        rename_filename(path=npy_path)
-        generate_construct(npy_path)
-        generate_dump_info(npy_path)
-        generate_stack_info(npy_path)
-        # Single-card scenario: the rank directory is simply named "rank"
-        if rank_id is None:
-            new_rank_path = os.path.join(step_path, CoreConst.RANK)
-            try:
-                move_directory(rank_path, new_rank_path)
-                logger.info(f"Directory was successfully renamed to: {new_rank_path}")
-            except Exception as e:
-                logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
-        logger.info("==========JSON file generation completed!==========")
+    step_dir = CoreConst.STEP + str(step)
+
+    step_path = os.path.join(dump_path, step_dir)
+    rank_path = os.path.join(step_path, rank_dir)
+    npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
+    save_finish_flag = f"step_{step}"
+    start_time = time.time()
+    while True:
+        is_finished = is_download_finished(flag_path, save_finish_flag)
+        if not is_finished:
+            logger.info("There is data being downloaded in the specified directory, continue checking...")
+        else:
+            logger.info("There is no data being downloaded in the specified directory, Stop checking.")
+            break
+        elapsed_time = time.time() - start_time
+        if elapsed_time > TIMEOUT:
+            logger.error(f"Check timed out after {TIMEOUT} seconds. Exiting.")
+            return
+    logger.info(f"==========Start processing step_{step}'s data that has already been stored on the disk!==========")
+    rename_filename(path=npy_path)
+    generate_construct(npy_path)
+    generate_dump_info(npy_path)
+    generate_stack_info(npy_path)
+    # Single-card scenario: the rank directory is simply named "rank"
+    if rank_id is None:
+        new_rank_path = os.path.join(step_path, CoreConst.RANK)
+        try:
+            move_directory(rank_path, new_rank_path)
+            logger.info(f"Directory was successfully renamed to: {new_rank_path}")
+        except Exception as e:
+            logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
+    logger.info(f"==========Step_{step}'s JSON file generation completed!==========")
 
 
 # Remove the trailing comma at the end of each line of the csv file
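The hunks above replace the old mtime-based wait in is_download_finished with a flag-file check bounded by the new TIMEOUT constant. A minimal, self-contained Python sketch of that polling pattern (hypothetical helper name and paths, not code from the package):

    import os
    import time

    TIMEOUT = 600  # seconds, mirroring the constant added in cell_dump_process.py

    def wait_for_step_flag(flag_path, step):
        """Poll flag_path until an entry named 'step_<step>*' appears, or give up after TIMEOUT."""
        save_finish_flag = f"step_{step}"
        start_time = time.time()
        while True:
            time.sleep(0.5)  # short pause so the loop does not hammer the disk with listdir calls
            if os.path.isdir(flag_path) and any(
                    entry.startswith(save_finish_flag) for entry in os.listdir(flag_path)):
                return True
            if time.time() - start_time > TIMEOUT:
                return False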
@@ -689,7 +702,10 @@ def merge_file(dump_path, rank_dir, file_dict):
                     " and the index is out of bounds.")
 
 
-def process_statistics(dump_path):
+def process_statistics_step(dump_path, step, step_list):
+    if step_list and step not in step_list:
+        return
+
     if not os.path.exists(dump_path):
         logger.warning('No grap cell data is dumped.')
     create_directory(dump_path)
@@ -723,25 +739,24 @@
 
     rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '')
     dir_list = os.listdir(dump_path)
-    step_dir_list = [d for d in dir_list if d.startswith(CoreConst.STEP)]
-    for step_dir in step_dir_list:
-        step_path = os.path.join(dump_path, step_dir)
-        rank_path = os.path.join(step_path, rank_dir)
-        csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
-        logger.info("==========Start processing data csv!==========")
-        generate_construct(csv_path)
-        generate_dump_info(csv_path)
-        generate_stack_info(csv_path)
-        remove_path(rank_path_kbk)
-        # Single-card scenario: the rank directory is simply named "rank"
-        if rank_id is None:
-            new_rank_path = os.path.join(step_path, CoreConst.RANK)
-            try:
-                move_directory(rank_path, new_rank_path)
-                logger.info(f"Directory was successfully renamed to: {new_rank_path}")
-            except Exception as e:
-                logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
-        logger.info("==========JSON file generation completed!==========")
+    step_dir = CoreConst.STEP + str(step)
+    step_path = os.path.join(dump_path, step_dir)
+    rank_path = os.path.join(step_path, rank_dir)
+    csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
+    logger.info("==========Start processing data csv!==========")
+    generate_construct(csv_path)
+    generate_dump_info(csv_path)
+    generate_stack_info(csv_path)
+    remove_path(rank_path_kbk)
+    # Single-card scenario: the rank directory is simply named "rank"
+    if rank_id is None:
+        new_rank_path = os.path.join(step_path, CoreConst.RANK)
+        try:
+            move_directory(rank_path, new_rank_path)
+            logger.info(f"Directory was successfully renamed to: {new_rank_path}")
+        except Exception as e:
+            logger.warning(f"Failed to renamed to {new_rank_path}: {e}")
+    logger.info("==========JSON file generation completed!==========")
 
 
 def get_yaml_keys(yaml_data):
@@ -922,7 +937,3 @@ def start(config: CellDumpConfig):
         cell.data_mode = data_mode
 
     logger.info("==========The cell_dump_process_start phase is Finished!==========")
-    if dump_task == CoreConst.TENSOR:
-        atexit.register(process, dump_path=dump_path)
-    if dump_task == CoreConst.STATISTICS:
-        atexit.register(process_statistics, dump_path=dump_path)
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 import os
-
+import glob
+import tempfile
 import mindspore as ms
 from mindspore import hal, ops, Tensor
 from mindspore.ops.primitive import _run_op
@@ -28,6 +29,7 @@ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
 import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
 
 tensordump_flag = True
+DEFAULT_RANK_DIR = "rank0"
 try:
     from mindspore._c_expression import _tensordump_set_step
 except ImportError:
@@ -41,8 +43,6 @@ except ImportError:
 
 
 class GraphModeCellDump:
-    task = CoreConst.STATISTICS
-
     def __init__(self, config: DebuggerConfig, model, strict=True):
         self.net = model
         self.white_list = []
@@ -55,29 +55,40 @@ class GraphModeCellDump:
         self.list = config.list
         self.data_mode = config.data_mode
         self.file_format = config.file_format
-        GraphModeCellDump.task = config.task
         self.summary_mode = config.summary_mode
+        self.task = config.task
         self.check_config(strict)
         self.set_step()
 
     @staticmethod
-    def step():
+    def step(dump_path, step_list, task):
         # Update the TensorDump step
-        if GraphModeCellDump.task == CoreConst.TENSOR:
+        if task == CoreConst.TENSOR:
             hal.synchronize()
             temp_tensor = ms.Tensor([1], dtype=ms.float32)
-            step_flag = "<tensordump-update-step>"
-            _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
-            ops.tensordump(step_flag, temp_tensor)
+            rank_id = os.environ.get('RANK_ID')
+            rank_dir = DEFAULT_RANK_DIR
+
+            if rank_id is not None:
+                rank_dir = CoreConst.RANK + str(rank_id)
+
+            with tempfile.TemporaryDirectory(dir=dump_path, prefix=rank_dir) as temp_dir:
+                save_file_flag = f"{temp_dir}/step_{Runtime.step_count}"
+                _run_op(ops.TensorDump(), "TensorDump", (save_file_flag, temp_tensor))
+                step_flag = "<tensordump-update-step>"
+                _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
+                ops.tensordump(step_flag, temp_tensor)
+                cellDumperWithDumpGradient.process_step(dump_path, temp_dir, Runtime.step_count, step_list)
 
         # Update the step count of the static-graph KBK dump
-        if GraphModeCellDump.task == CoreConst.STATISTICS:
+        if task == CoreConst.STATISTICS:
             if not graph_step_flag:
                 raise Exception(
                     "Importing _dump_step failed, "
                     "please use the latest version package of MindSpore."
                 )
             _dump_step(1)
+            cellDumperWithDumpGradient.process_statistics_step(dump_path, Runtime.step_count, step_list)
 
     def check_config(self, strict):
         if not self.net:
@@ -203,10 +203,12 @@ class MindsporeHookManager(BaseHookManager):
             return
 
         with ThreadSafe():
+            original_state = self.ensure_gc_enabled()
             BaseHookManager.inner_switch[tid] = True
             module_input = ModuleBackwardInputs(grad_input=grad_input)
            self.data_collector.update_api_or_module_name(full_name)
            self.data_collector.backward_input_data_collect(full_name, module, self._pid, module_input)
            BaseHookManager.inner_switch[tid] = False
+            self.restore_gc_state(original_state)
 
         return backward_pre_hook
msprobe/msprobe.py CHANGED
@@ -14,16 +14,16 @@
 # limitations under the License.
 
 import argparse
-import sys
 import importlib.util
+import sys
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import root_privilege_warning
 from msprobe.core.common.log import logger
-from msprobe.core.compare.utils import _compare_parser
 from msprobe.core.compare.compare_cli import compare_cli
 from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
-from msprobe.core.config_check.config_check_cli import _config_checking_parser, \
-    _run_config_checking_command
+from msprobe.core.compare.utils import _compare_parser
+from msprobe.core.config_check.config_check_cli import _config_checking_parser, _run_config_checking_command
 
 
 def is_module_available(module_name):
@@ -64,6 +64,8 @@ def main():
     if len(sys.argv) < 4:
         parser.print_help()
         sys.exit(0)
+
+    root_privilege_warning()
     framework_args = parser.parse_args(sys.argv[1:3])
     if framework_args.framework == Const.PT_FRAMEWORK:
         from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
@@ -24,7 +24,8 @@ from msprobe.pytorch.pt_config import RunUTConfig
 
 RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
                                          'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
-                                         'black_list', 'error_data_path'])
+                                         'black_list', 'error_data_path', 'online_config'])
+OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
 
 
 class Config:
@@ -45,7 +46,13 @@ class Config:
             'white_list': list,
             'black_list': list,
             'error_data_path': str,
-            'precision': int
+            'precision': int,
+            'is_online': bool,
+            'nfs_path': str,
+            'host': str,
+            'port': int,
+            'rank_list': list,
+            'tls_path': str
         }
         if key not in validators:
             raise ValueError(f"{key} must be one of {validators.keys()}")
@@ -61,6 +68,10 @@ class Config:
             RunUTConfig.check_filter_list_config(key, value)
         if key == 'error_data_path':
             RunUTConfig.check_error_data_path_config(value)
+        if key == 'nfs_path':
+            RunUTConfig.check_nfs_path_config(value)
+        if key == 'tls_path':
+            RunUTConfig.check_tls_path_config(value)
         return value
 
 
@@ -74,6 +85,12 @@ class CheckerConfig:
         self.white_list = msCheckerConfig.white_list
         self.black_list = msCheckerConfig.black_list
         self.error_data_path = msCheckerConfig.error_data_path
+        self.is_online = msCheckerConfig.is_online
+        self.nfs_path = msCheckerConfig.nfs_path
+        self.host = msCheckerConfig.host
+        self.port = msCheckerConfig.port
+        self.rank_list = msCheckerConfig.rank_list
+        self.tls_path = msCheckerConfig.tls_path
 
         if task_config:
             self.load_config(task_config)
@@ -82,7 +99,22 @@
         self.white_list = task_config.white_list
         self.black_list = task_config.black_list
         self.error_data_path = task_config.error_data_path
+        self.is_online = task_config.is_online
+        self.nfs_path = task_config.nfs_path
+        self.host = task_config.host
+        self.port = task_config.port
+        self.rank_list = task_config.rank_list
+        self.tls_path = task_config.tls_path
 
+    def get_online_config(self):
+        return OnlineConfig(
+            is_online=self.is_online,
+            nfs_path=self.nfs_path,
+            host=self.host,
+            port=self.port,
+            rank_list=self.rank_list,
+            tls_path=self.tls_path
+        )
 
     def get_run_ut_config(self, **config_params):
         return RunUtConfig(
@@ -95,5 +127,6 @@
             real_data_path=config_params.get('real_data_path'),
             white_list=self.white_list.copy() if self.white_list else [],
             black_list=self.black_list.copy() if self.black_list else [],
-            error_data_path=config_params.get('error_data_path')
+            error_data_path=config_params.get('error_data_path'),
+            online_config=self.get_online_config()
         )
@@ -117,6 +117,30 @@ def api_precision_compare(config):
     change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
 
 
+def online_api_precision_compare(online_config):
+    rank = online_config.rank
+    result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
+        "_rank*.csv", f"_rank{rank}.csv")
+    details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
+        "_rank*.csv", f"_rank{rank}.csv")
+    detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
+    result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
+    if not os.path.exists(result_csv_path):
+        write_csv(result_csv_title, result_csv_path)
+    if not os.path.exists(details_csv_path):
+        write_csv(detail_csv_title, details_csv_path)
+    config = CompareConfig("", "", result_csv_path, details_csv_path)
+    try:
+        npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
+        check_csv_columns(npu_data.columns, "npu_csv")
+        check_csv_columns(gpu_data.columns, "gpu_csv")
+        analyse_csv(npu_data, gpu_data, config)
+    except Exception as err:
+        logger.error(f"Online api precision compare Error: {str(err)}")
+    change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
+    change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
 def analyse_csv(npu_data, gpu_data, config):
     forward_status, backward_status = [], []
     last_api_name, last_api_dtype, last_api_full_name = None, None, None
@@ -66,6 +66,13 @@ class Comparator:
         self.save_path_list = [result_csv_path]
         self.detail_save_path_list = [details_csv_path]
 
+        if config and config.online_config.is_online:
+            self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
+            self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
+            self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
+            self.detail_save_path_list = \
+                [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
+
         self.registry = self._register_compare_func()
 
         if not is_continue_run_ut:
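The online branch above simply pre-computes one result CSV and one detail CSV per rank. A small illustration of the "_rank{}" templating, with made-up file names and ranks:

    # Made-up base name and ranks, purely to show the per-rank expansion.
    result_csv_path = "accuracy_checking_result.csv"
    rank_list = [0, 1, 2]

    save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
    save_path_list = [save_path_str.format(rank) for rank in rank_list]
    # ['accuracy_checking_result_rank0.csv', 'accuracy_checking_result_rank1.csv',
    #  'accuracy_checking_result_rank2.csv']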
@@ -238,8 +245,9 @@
         self.write_detail_csv(args)
 
 
-    def compare_output(self, full_api_name, data_info):
+    def compare_output(self, full_api_name, data_info, is_online=False):
         """Get compare result and write to result and detail csv.
+        is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
         """
         _, api_name = extract_basic_api_segments(full_api_name)
         if not api_name:
@@ -272,7 +280,9 @@
                                              fwd_compare_alg_results,
                                              bwd_compare_alg_results,
                                              data_info.rank)
-
+        if is_online:
+            # get run_ut compare detail
+            return self._get_run_ut_detail(result_info)
         self.record_results(result_info)
         return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
             or bwd_success_status == CompareConst.SPACE
@@ -2,4 +2,9 @@ white_list: []
 black_list: []
 error_data_path: './'
 precision: 14
-
+is_online: False
+nfs_path: ""
+host: ""
+port: -1
+rank_list: [0]
+tls_path: "./"
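These new config.yaml keys back the OnlineConfig namedtuple added in common/config.py above. A hedged sketch of how the values line up; the "enabled" values below are illustrative placeholders, not recommended settings:

    from collections import namedtuple

    # Same field list as the OnlineConfig namedtuple introduced in this release.
    OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])

    online_config = OnlineConfig(
        is_online=True,       # config.yaml default is False
        nfs_path="",          # assumption: left empty when not dumping to a shared filesystem
        host="127.0.0.1",     # placeholder address
        port=8123,            # placeholder port (config.yaml default is -1)
        rank_list=[0, 1],     # ranks that get per-rank result CSVs (see the Comparator change above)
        tls_path="./",
    )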
@@ -47,7 +47,7 @@ API_INFO = 2
 FOUR_SEGMENT = 4
 FIVE_SEGMENT = 5
 DATA_NAME = "data_name"
-API_MAX_LENGTH = 30
+API_MAX_LENGTH = 300
 PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
 DATAMODE_LIST = ["random_data", "real_data"]
 ITER_MAX_TIMES = 1000