mindstudio-probe 8.3.1__py3-none-any.whl → 8.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mindstudio-probe
3
- Version: 8.3.1
3
+ Version: 8.3.3
4
4
  Summary: Ascend Probe Utils
5
5
  Home-page: https://gitcode.com/Ascend/mstt/tree/master/debug/accuracy_tools/msprobe
6
6
  Author: Ascend Team
@@ -32,7 +32,7 @@ msprobe/core/compare/highlight.py,sha256=iNgkVAUSfJlxKt0NC8A78XpWI5aNqJrxyCRWYVG
32
32
  msprobe/core/compare/ms_to_pt_api.yaml,sha256=NGzy6_yIArM6V0zYsW3sg3KLMJe0sr9ljKiIcHujJus,13203
33
33
  msprobe/core/compare/multiprocessing_compute.py,sha256=Yy7eNUyP4u9HHn7BLAK6F0FddR1JkY01qUNllPI30yE,14262
34
34
  msprobe/core/compare/npy_compare.py,sha256=aCylTUJuLUEsqqm-KySnY24egh5b36rTFh_dXfaJx84,12712
35
- msprobe/core/compare/utils.py,sha256=8vWFoqgqBY8XRc7eDOX2Lzaln26aPmaLF8Vs_xBQovk,35416
35
+ msprobe/core/compare/utils.py,sha256=Aklo1fbe4J22O7JnYFxstaCP6j191gKYyhyx4QX3dm0,36156
36
36
  msprobe/core/compare/diff_analyze/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml,sha256=77K_1RM8ICVuDSNVTFIRdeBu15HyzR2Kw9YjhHaoaT0,193
38
38
  msprobe/core/compare/diff_analyze/first_diff_analyze.py,sha256=zXfaVYuMwA-y0upjQaycuQ2Q7dc2kqH8kAgWcWqXkVs,4755
@@ -101,11 +101,11 @@ msprobe/docs/01.installation.md,sha256=I4iA214liJglmn7T6moGlipnWNy2y1aKNWAoHsU0B
101
101
  msprobe/docs/02.config_introduction.md,sha256=EVzpOM38ZsFde3PTQkR2LYepix-lpLbYIr9uKZuBNtU,32272
102
102
  msprobe/docs/03.config_examples.md,sha256=lHv_HLaCxk1nlSPIUF7wQVO0RR6_Q5b9XrSp96fr1FE,4408
103
103
  msprobe/docs/04.kernel_dump_PyTorch.md,sha256=1qXNBrdkO2r8FJjHObr1r2XGy7gIzhvZ3DWC0JjDPng,3034
104
- msprobe/docs/05.data_dump_PyTorch.md,sha256=0ZLkESAFa_HBlo8nndc7cWXQ0O1l7BuNCCy7eARTcm8,22242
104
+ msprobe/docs/05.data_dump_PyTorch.md,sha256=UnswISCnklt0IYPVBvtnclBb-gq0an2ozoWoq-jX0v4,22243
105
105
  msprobe/docs/06.data_dump_MindSpore.md,sha256=pweNRLRxNy1cu8h5cgmEf9auOBuKN3-Ezen5Et7a1s8,29284
106
- msprobe/docs/07.accuracy_checker_PyTorch.md,sha256=tLHA1eD0ldNv_4uQKArN-HzKcPdQbQighhK9VKuhuxQ,30993
106
+ msprobe/docs/07.accuracy_checker_PyTorch.md,sha256=2oLP43_0cgvF1rLza3xQGST9WvKfFjBocLLSA8tZG6E,31006
107
107
  msprobe/docs/09.accuracy_checker_MindSpore.md,sha256=XKEnEgqPJKZxL7JL3evZeIAa_LbMh8DR7sILqJ1camw,11320
108
- msprobe/docs/10.accuracy_compare_PyTorch.md,sha256=3H_fHJBMRweqAxEOT6trHSQPAMjP3xD9G3T_PIJPtVY,41611
108
+ msprobe/docs/10.accuracy_compare_PyTorch.md,sha256=xZU4cKF9JNFVQKxsKOg7O5hx03gqQjFYrOMGWcp8pd8,41635
109
109
  msprobe/docs/11.accuracy_compare_MindSpore.md,sha256=9zalk8EiYA_ex5kkiGEIn8yaNygLmJPplUoUEPs-C5w,40792
110
110
  msprobe/docs/12.overflow_check_PyTorch.md,sha256=VE76J_rFAFm0pw_bQEnKIV4FdXek4x7cenxTQQz2FK0,3793
111
111
  msprobe/docs/13.overflow_check_MindSpore.md,sha256=G67GjHo2VERAKsr1_uX0m67eNuG8E9aN6tmZiO1c_zg,2799
@@ -114,7 +114,7 @@ msprobe/docs/15.free_benchmarking_PyTorch.md,sha256=6pIpLXx5lUu45_kdL31rhj9zV3Ew
114
114
  msprobe/docs/16.free_benchmarking_MindSpore.md,sha256=swCOrnBSzU6Q5I0AHVwi2r0JfKp3VE1DgNHripyE01M,8195
115
115
  msprobe/docs/17.grad_probe.md,sha256=9g1aq6FettgpvzBxKj5C5W8bTsENPtt2itYaNOLcBk8,9624
116
116
  msprobe/docs/18.online_dispatch.md,sha256=Ae9ONIXF3wA2u0tikuDxV0nea_n5TIGug-PKTI9c7Ws,4170
117
- msprobe/docs/19.monitor.md,sha256=ogaecwFA5vpEyEEKuXaWmbaoClnPdtRedYhntppTAYo,57192
117
+ msprobe/docs/19.monitor.md,sha256=TDfLjTEoN4nEUikeqici0MdAtLmuaXsYdGHjWhhNPl4,57194
118
118
  msprobe/docs/20.monitor_performance_baseline.md,sha256=t-aM1s7BqEE8ls47gL1wmipkhdmB6CSPI90-r4abyXY,3672
119
119
  msprobe/docs/21.visualization_PyTorch.md,sha256=lEt72aqsT5voRk38vAapXSJA_nzbu3DWq3-mEYKZowU,27678
120
120
  msprobe/docs/22.visualization_MindSpore.md,sha256=jtMgasjacyNEevhmrMT5Y8IOsAgEuDzBN7vvZYBJl7s,26605
@@ -319,12 +319,12 @@ msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py,sha256=ZzVF13owq
319
319
  msprobe/mindspore/overflow_check/overflow_check_tool_factory.py,sha256=6JqOsKCaNZbUi_pXCv6uGqHQX4yhhauqoPziyI97FUw,1909
320
320
  msprobe/nan_analyze/__init__.py,sha256=S8S4SIEKDrF4pk8alDPapKUdlTHzok2jFc5kqk-JIL0,617
321
321
  msprobe/nan_analyze/analyzer.py,sha256=EEMwQWm33JgXgvLudoMRkIYDQpQaqViFpl6BFMIrOtM,11667
322
- msprobe/nan_analyze/graph.py,sha256=uSEYUC4I1KhGPUyi3E1BblNNpxUY3fjpf_sNnaaY7w4,8085
322
+ msprobe/nan_analyze/graph.py,sha256=PwuJL62Sgg7ANZKpCPMwPxx41cxvUiXZMDkLgkTUUI8,8164
323
323
  msprobe/nan_analyze/utils.py,sha256=jwpnHcfluzBlytAcUuyAotj0Hw3zmi7e8ncf7rME2JU,7666
324
324
  msprobe/pytorch/__init__.py,sha256=qIvhnAk61oSpvPU0QI0YAC4zyLKYyOcfNzdXJxN7Klo,1035
325
325
  msprobe/pytorch/function_factory.py,sha256=Fi4w0zfO64Sd2IU9z45mBbDgS8k_CQEiZ9vRpC_TkVk,4031
326
326
  msprobe/pytorch/pt_config.py,sha256=1A6RhK6BBdMxXdYfLIcvtkOFUSXEaRMuNvaK6stjl7o,12284
327
- msprobe/pytorch/pytorch_service.py,sha256=f_ItlE28Ruhr1GgMjHA4iqg1sAqRthi129gfsFZFVDQ,2590
327
+ msprobe/pytorch/pytorch_service.py,sha256=NJiKpMm9lSwzXCZmHmNWnaaM7xuHCcDcxDDQX9MMLl4,2664
328
328
  msprobe/pytorch/api_accuracy_checker/.keep,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
329
329
  msprobe/pytorch/api_accuracy_checker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
330
330
  msprobe/pytorch/api_accuracy_checker/config.yaml,sha256=JDD7boERYCh7vd5x74ZcHXmS8afthdj6-pYGIqHinDY,67
@@ -427,7 +427,7 @@ msprobe/pytorch/grad_probe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
427
427
  msprobe/pytorch/grad_probe/grad_monitor.py,sha256=xb3M1XguMxVsSgZuYwyB08tz_jno1-43cURs2pG1FYU,4910
428
428
  msprobe/pytorch/grad_probe/grad_stat_csv.py,sha256=k0YesT0Jy-2LsUlSYEDt8eRuhhfFhpSGv4ImtE74NH0,4940
429
429
  msprobe/pytorch/hook_module/__init__.py,sha256=ktS1W-6hBa1q8wF-DGz8kFsfVhiIPWBOj6C3kSatgno,679
430
- msprobe/pytorch/hook_module/api_register.py,sha256=Vfdh-t0zBWinbzuXFqZ4WCSTTKexOFpaGAuEz2gUMgU,7313
430
+ msprobe/pytorch/hook_module/api_register.py,sha256=nwoQTO6EiOnuoEzWVxkX1x6uGii7qj6oYQIff13HubI,7582
431
431
  msprobe/pytorch/hook_module/hook_module.py,sha256=I1i_vISgVcLPd4uEYMMfr7Xfz76c3ASO0-XOcIoKxVA,3738
432
432
  msprobe/pytorch/hook_module/pt_hook_manager.py,sha256=x77ZU_YNYs5PB9snbopxSgyBmDNBEDQD5JeK1bcano4,5199
433
433
  msprobe/pytorch/hook_module/register_optimizer_hook.py,sha256=3JS0VqHAswQ7SlGXQ9gS9o4kGQYZflFAwOrJhUpB5gk,2539
@@ -467,8 +467,8 @@ msprobe/pytorch/parse_tool/lib/parse_tool.py,sha256=nXX1b45o1AhpIR4oG1VpN56Jh02j
467
467
  msprobe/pytorch/parse_tool/lib/utils.py,sha256=c4tyO_QeeqdB5ysJKXcckfVshV5HuV9j7_Hr-a3cB6s,11718
468
468
  msprobe/pytorch/parse_tool/lib/visualization.py,sha256=_M1R4kiyc3zdxyoHRqFoxMEfm4klQtscXsOLBevWnEk,4071
469
469
  msprobe/visualization/__init__.py,sha256=rEvJTAx-jTGHBM7-bB6VZ7fwfWwYA961M1eZJj4sRcY,622
470
- msprobe/visualization/db_utils.py,sha256=GhazcnqNQ27n-kpfxcuW8nK83_uC336S3-C0zpoNZ2I,9449
471
- msprobe/visualization/graph_service.py,sha256=_GfTbRtutN5Yk25cvV9kzD5kVWrgcHD5t6gkSbX5P5o,26302
470
+ msprobe/visualization/db_utils.py,sha256=R_sRqrAohVBimyVV5ZCvTl8joB7K30wAp7kLGlrQp7g,10356
471
+ msprobe/visualization/graph_service.py,sha256=xz4KjeBQP9ycEPTmOMobJT-uglPQP7spXDCZ9AxIb_M,26812
472
472
  msprobe/visualization/utils.py,sha256=P5Ds9uMRoMPRNft8XTSiqDrGAbEG7bQbIZD6qJy3Lvs,12620
473
473
  msprobe/visualization/builder/__init__.py,sha256=rEvJTAx-jTGHBM7-bB6VZ7fwfWwYA961M1eZJj4sRcY,622
474
474
  msprobe/visualization/builder/graph_builder.py,sha256=yL7rHq4F94zWYBYJCyQvY3gLwzhLWnV6oLsVfUZ8AZo,20535
@@ -483,9 +483,9 @@ msprobe/visualization/graph/distributed_analyzer.py,sha256=IVPaA6Z3uZxuOob7fibgc
483
483
  msprobe/visualization/graph/graph.py,sha256=SFf_x4a3e5ejUe6PntPeEqnGCPaR6gUWpak6w_hyIBE,8776
484
484
  msprobe/visualization/graph/node_colors.py,sha256=7LpurTuE3edKGilwLVsXD7Ue4bMT7Maz63Udq_6CADM,4504
485
485
  msprobe/visualization/graph/node_op.py,sha256=qkbw3ZJkKGQH071C1CLIXi4kavvsLxBTzt0KXncz6fY,1368
486
- mindstudio_probe-8.3.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
487
- mindstudio_probe-8.3.1.dist-info/METADATA,sha256=r3DZB8reM8qZOOrmbvMGNGFRKXQaH5y86T4GWCcgQ_M,1437
488
- mindstudio_probe-8.3.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
489
- mindstudio_probe-8.3.1.dist-info/entry_points.txt,sha256=4ob3a9L018EBZFdlfgMW1lbgeIOhc4F-HCR8gBksaCQ,49
490
- mindstudio_probe-8.3.1.dist-info/top_level.txt,sha256=LxFEFqelENSyWmRtocCiEUF04IE8aZvwTl7ADB598Tk,8
491
- mindstudio_probe-8.3.1.dist-info/RECORD,,
486
+ mindstudio_probe-8.3.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
487
+ mindstudio_probe-8.3.3.dist-info/METADATA,sha256=zb5GEQTdL6OI9SvkeSJDwFk9h_gNNcC7GtrA4pOKseE,1437
488
+ mindstudio_probe-8.3.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
489
+ mindstudio_probe-8.3.3.dist-info/entry_points.txt,sha256=4ob3a9L018EBZFdlfgMW1lbgeIOhc4F-HCR8gBksaCQ,49
490
+ mindstudio_probe-8.3.3.dist-info/top_level.txt,sha256=LxFEFqelENSyWmRtocCiEUF04IE8aZvwTl7ADB598Tk,8
491
+ mindstudio_probe-8.3.3.dist-info/RECORD,,
@@ -789,23 +789,43 @@ def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare
789
789
 
790
790
  npu_ranks, bench_ranks = get_sorted_ranks(npu_dump_dir, bench_dump_dir)
791
791
 
792
- # 统计量、md5比对
793
- pre_check_dump_path = os.path.join(npu_dump_dir, npu_ranks[0], 'dump.json') if npu_ranks else ''
792
+ # ------------------预载rank0的json用于判断是什么类型dump数据------------------
793
+ # 判断是否存在dump.json或debug.json
794
+ if npu_ranks:
795
+ dir_path = os.path.join(npu_dump_dir, npu_ranks[0])
796
+ dump_file = os.path.join(dir_path, 'dump.json')
797
+ debug_file = os.path.join(dir_path, 'debug.json')
798
+
799
+ # 确定pre_check_dump_path
800
+ if os.path.exists(dump_file):
801
+ pre_check_dump_path = dump_file
802
+ elif os.path.exists(debug_file):
803
+ pre_check_dump_path = debug_file
804
+ else:
805
+ pre_check_dump_path = ''
806
+ else:
807
+ pre_check_dump_path = ''
808
+
809
+ # 如果pre_check_dump_path为空,直接返回
794
810
  if not pre_check_dump_path:
795
811
  return
812
+
796
813
  dump_data = load_json(pre_check_dump_path)
814
+
815
+ # ------------------统计量、md5比对------------------
797
816
  if dump_data.get('task') == Const.STATISTICS:
798
817
  # dump数据为统计量或md5时,多进程加速比对
799
818
  input_param_nr_list = []
800
819
  for nr, br in zip(npu_ranks, bench_ranks):
801
- input_param, skip = extract_compare_param(Const.DUMP_JSON_FILE)
802
- if not skip:
803
- input_param_nr_list.append((input_param, nr))
820
+ for file_type in [Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE]:
821
+ input_param, skip = extract_compare_param(file_type)
822
+ if not skip:
823
+ input_param_nr_list.append((input_param, nr))
804
824
  func_args = (compare_func, input_param_nr_list, output_path, kwargs)
805
825
  multi_statistics_compare(multi_ranks_compare, func_args)
806
826
  return
807
827
 
808
- # 真实数据比对
828
+ # ------------------真实数据比对------------------
809
829
  for nr, br in zip(npu_ranks, bench_ranks):
810
830
  for file_type in [Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE]:
811
831
  input_param, skip = extract_compare_param(file_type)
@@ -445,7 +445,7 @@ seed_all()
445
445
  debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
446
446
  # 模型定义及初始化等操作
447
447
  prompts = ["Hello, my name is"]
448
- sampling_params = SamplingParams(temprature=0.8, top_p=0.95)
448
+ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
449
449
  llm = LLM(model='...')
450
450
  model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
451
451
  # 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次
@@ -34,17 +34,17 @@ run_ut 预检操作包括以下两种方式:
34
34
  msprobe -f pytorch run_ut -api_info ./dump_path/step{step_number}/rank{rank_number}/dump.json
35
35
  ```
36
36
 
37
- | 参数名称 | 解释 | 是否必选 |
38
- |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ---------------------------------- |
39
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
40
- | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 |
41
- | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 |
42
- | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 |
43
- | -j 或 --jit_compile | 开启 jit 编译。 | 否 |
44
- | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 |
45
- | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 |
46
- | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 |
47
- | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 |
37
+ | 参数名称 | 解释 | 是否必选 |
38
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ---------------------------------- |
39
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
40
+ | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 |
41
+ | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 |
42
+ | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 |
43
+ | -j 或 --jit_compile | 开启 jit 编译。 | 否 |
44
+ | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 |
45
+ | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 |
46
+ | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 |
47
+ | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 |
48
48
 
49
49
  run_ut 执行结果包括 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 两个文件。`accuracy_checking_result_{timestamp}.csv` 属于 API 级,标明每个 API 是否通过测试。建议用户先查看 `accuracy_checking_result_{timestamp}.csv` 文件,对于其中没有通过测试的或者特定感兴趣的 API,根据其 API name 字段在 `accuracy_checking_details_{timestamp}.csv` 中查询其各个输出的达标情况以及比较指标。详细介绍请参见[ 4 预检结果](#4-预检结果)。
50
50
 
@@ -104,7 +104,7 @@ msprobe -f pytorch multi_run_ut -api_info ./dump_path/step{step_number}/rank{ran
104
104
 
105
105
  | 参数名称 | 解释 | 是否必选 |
106
106
  | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- |
107
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
107
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
108
108
  | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 |
109
109
  | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 |
110
110
  | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 |
@@ -216,7 +216,7 @@ msprobe -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_de
216
216
 
217
217
  | 参数名称 | 说明 | 是否必选 |
218
218
  |-----------------------| ------------- | -------- |
219
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
219
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
220
220
  | -npu 或 --npu_csv_path | NPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 |
221
221
  | -gpu 或 --gpu_csv_path | GPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 |
222
222
  | -o 或 --out_path | 指定 api_precision_compare.py 执行结果存盘路径,默认为当前目录。 | 否 |
@@ -53,15 +53,15 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s
53
53
 
54
54
  | 参数名 | 说明 | 是否必选 |
55
55
  |---------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- |
56
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
56
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
57
57
  | -i 或 --input_path | 指定[比对文件](#51-比对文件),str 类型。 | 是 |
58
58
  | -o 或 --output_path | 配置比对结果文件存盘目录,str 类型,默认在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。<br>提示:output目录下与结果件同名文件将被删除覆盖。 | 否 |
59
59
  | -s 或 --stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,根据[比对文件](#51-比对文件)的参数说明配置stack_path;多卡场景开启时,自动识别npu_dump目录下stack.json文件,如存在生成详细调用栈信息,否则不生成,此参数不生效。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 |
60
60
  | -c 或 --compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 |
61
61
  | -f 或 --fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 |
62
62
  | -hl 或 --highlight | 高亮颜色标记。开启后,比对结果件中通过红色或黄色标记精度可疑API或模块。通过直接配置该参数开启,默认未配置,表示关闭。 开启高亮颜色标记后,比对性能降低,如果比对结果行数超出excel单页限制,程序强制关闭高亮颜色标记。 | 否 |
63
- | -dm或--data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#52-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#51-比对文件)的单卡场景示例。 | 否 |
64
- | -da或--diff_analyze | 自动识别网络中首差异节点,支持md5、统计量等dump数据。支持单卡/多卡场景。 | 否 |
63
+ | -dm 或 --data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#52-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#51-比对文件)的单卡场景示例。 | 否 |
64
+ | -da 或 --diff_analyze | 自动识别网络中首差异节点,支持md5、统计量等dump数据。支持单卡/多卡场景。 | 否 |
65
65
 
66
66
  #### 2.1.2 整网比对场景
67
67
 
@@ -396,7 +396,7 @@ msprobe -f pytorch merge_result -i ./input_dir -o ./output_dir -config ./config.
396
396
 
397
397
  | 参数名 | 说明 | 是否必选 |
398
398
  | --------------------- |-------------------------------------------------------------------------------------------------------------------| -------- |
399
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
399
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
400
400
  | -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 |
401
401
  | -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。<br>提示:output目录下与结果件同名文件将被删除覆盖。 | 是 |
402
402
  | -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。<br>yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 |
@@ -527,7 +527,7 @@ input_args、input_kwargs和output使用统一的命名规则,当值是list类
527
527
  "input_args": [
528
528
  {
529
529
  "type": "torch.Tensor",
530
- "dytpe": "torch_float32",
530
+ "dtype": "torch_float32",
531
531
  "shape": [
532
532
  1,
533
533
  64,
@@ -604,7 +604,7 @@ output是list,长度为1,第0项后面是Tensor,命名结束;按照顺
604
604
  ```
605
605
  Functional.max_pool2d.0.forward.output.0
606
606
  ```
607
- 综上,生成的的op_name为
607
+ 综上,生成的op_name为
608
608
  ```
609
609
  Functional.max_pool2d.0.forward.input.0
610
610
  Functional.max_pool2d.0.forward.input.1
@@ -24,7 +24,7 @@
24
24
  | [采集module堆栈信息](#采集module堆栈信息) | 采集监控的第一个 step 的 module 对应的堆栈信息辅助问题定位 | PyTorch、MindSpore |
25
25
  | [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore |
26
26
  | [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch |
27
- | [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch |
27
+ | [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch、MindSpore |
28
28
  | [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`、`ndigits`均支持 | PyTorch、MindSpore |
29
29
  | [mbs粒度梯度监控](#mbs粒度梯度监控) | 开启梯度监控时,采集聚合前梯度时支持`micro_batch_size`粒度 | PyTorch、MindSpore |
30
30
  | [异常告警](#异常告警) | 监控对象指标异常时自动告警,支持异常数据落盘 | PyTorch、MindSpore |
@@ -37,9 +37,9 @@
37
37
  推荐使用方式:权重梯度的监控性能损耗小(20B dense模型全量权重梯度监控,时间增加<1%,内存增加<1%),可以长期开启。激活值监控性能损耗大,在必要时开启或者仅监控部分。
38
38
 
39
39
  ### 工具使能
40
- 在实际训练代码中找到模型、优化器定义的位置,使能monitor工具,通过配置文件(json)控制工具行为。如下分别为Pytorch场景和MindSpore场景下的使能方式。
40
+ 在实际训练代码中找到模型、优化器定义的位置,使能monitor工具,通过配置文件(json)控制工具行为。如下分别为PyTorch场景和MindSpore场景下的使能方式。
41
41
 
42
- - Pytorch使能方式:
42
+ - PyTorch使能方式:
43
43
  ```python
44
44
  # Megatron-LM(core_r0.6.0) training.py
45
45
  model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
@@ -112,7 +112,7 @@ monitor.set_monitor(
112
112
 
113
113
  请注意以下两点:
114
114
  - Mindspore功能在1.2.2版本后支持, <1.2.2版本不支持
115
- - 上述接口使用方式为1.2.2后更新的最新接口使用方式, <1.2.2版本的Pytorch旧接口使用方式为:
115
+ - 上述接口使用方式为1.2.2后更新的最新接口使用方式, <1.2.2版本的PyTorch旧接口使用方式为:
116
116
  ```Python
117
117
  from msprobe.pytorch import TrainerMon
118
118
  monitor = TrainerMon(
@@ -46,6 +46,8 @@ class DataNode:
46
46
  seen = set(op_name)
47
47
  while True:
48
48
  op_name = construct_info.get(op_name)
49
+ if isinstance(op_name, list):
50
+ op_name = op_name[0]
49
51
  if not op_name or op_name in seen:
50
52
  return construct
51
53
  construct.insert(0, op_name)
@@ -22,6 +22,7 @@ import torch.distributed as dist
22
22
 
23
23
  from msprobe.core.common.const import Const
24
24
  from msprobe.core.common.file_utils import load_yaml
25
+ from msprobe.core.common.runtime import Runtime
25
26
  from msprobe.core.data_dump.api_registry import ApiRegistry
26
27
  from msprobe.pytorch.common.log import logger
27
28
  from msprobe.pytorch.common.utils import (
@@ -91,6 +92,12 @@ _inner_used_api = {
91
92
  }
92
93
 
93
94
 
95
+ def reset_dist_collect_func():
96
+ global dist_data_collect_func, dist_batch_data_collect_func
97
+ dist_data_collect_func.clear()
98
+ dist_batch_data_collect_func.clear()
99
+
100
+
94
101
  @parameter_adapter
95
102
  def tensor_module_forward(module, *args, **kwargs):
96
103
  return module.api_func(*args, **kwargs)
@@ -114,9 +121,9 @@ def dist_module_forward(module, *args, **kwargs):
114
121
 
115
122
  return store_data
116
123
 
117
- if use_async_op_flag or module.api_name in ['isend', 'irecv']:
124
+ if Runtime.is_running and (use_async_op_flag or module.api_name in ['isend', 'irecv']):
118
125
  dist_data_collect_func[handle] = create_async_callback_func(module.distributed_forward_hook)
119
- if module.api_name == 'batch_isend_irecv':
126
+ if Runtime.is_running and module.api_name == 'batch_isend_irecv':
120
127
  dist_batch_data_collect_func.append([handle, create_async_callback_func(module.distributed_forward_hook)])
121
128
  return handle
122
129
 
@@ -18,7 +18,12 @@ from msprobe.core.service import BaseService
18
18
  from msprobe.pytorch.common.log import logger
19
19
  from msprobe.pytorch.common.utils import get_rank_if_initialized
20
20
  from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
21
- from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
21
+ from msprobe.pytorch.hook_module.api_register import (
22
+ get_api_register,
23
+ ApiTemplate,
24
+ redirect_wait,
25
+ reset_dist_collect_func
26
+ )
22
27
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
23
28
  from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
24
29
  from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
@@ -33,7 +38,7 @@ class PytorchService(BaseService):
33
38
  @staticmethod
34
39
  def _get_current_rank():
35
40
  return get_rank_if_initialized()
36
-
41
+
37
42
  def reset_status(self):
38
43
  self._reset_status()
39
44
 
@@ -59,8 +64,8 @@ class PytorchService(BaseService):
59
64
  self.module_processor.register_module_hook(self.model, self.build_hook)
60
65
  self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
61
66
 
62
-
63
67
  def _reset_status(self):
64
68
  super()._reset_status()
65
69
  ModuleProcesser.reset_module_stats()
66
70
  HOOKModule.reset_module_stats()
71
+ reset_dist_collect_func()
@@ -17,6 +17,7 @@ import os
17
17
  import sqlite3
18
18
  import json
19
19
  import re
20
+ import time
20
21
  from msprobe.core.common.log import logger
21
22
  from msprobe.core.common.file_utils import change_mode, check_path_before_create, FileChecker
22
23
  from msprobe.core.common.const import FileCheckConst
@@ -133,33 +134,56 @@ def create_insert_sql_from_dict(table_name, columns_dict, ignore_insert=False):
133
134
 
134
135
 
135
136
  def to_db(db_path, create_table_sql, insert_sql, data, db_insert_size=1000):
137
+ max_retries = 10
138
+ initial_delay = 0.1
136
139
  if not os.path.exists(db_path):
137
140
  check_path_before_create(db_path)
138
141
  else:
139
142
  FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE,
140
143
  FileCheckConst.DB_SUFFIX).common_check()
141
- try:
142
- conn = sqlite3.connect(db_path)
143
- except sqlite3.Error as e:
144
- logger.error(f"Unable to create database connection: {e}")
145
- raise RuntimeError("Unable to create database connection") from e
146
144
 
147
- try:
148
- cursor = conn.cursor()
149
- cursor.execute(create_table_sql)
150
- if len(data) == 1:
151
- cursor.execute(insert_sql, data[0])
152
- conn.commit()
153
- else:
145
+ retry_count = 0
146
+ current_delay = initial_delay
147
+
148
+ while retry_count <= max_retries:
149
+ conn = None
150
+ try:
151
+ conn = sqlite3.connect(db_path, timeout=30)
152
+ cursor = conn.cursor()
153
+ # 启用WAL模式提升多进程读写并发能力
154
+ cursor.execute("PRAGMA journal_mode=WAL")
155
+ cursor.execute("PRAGMA synchronous=NORMAL")
156
+ cursor.execute(create_table_sql)
154
157
  for i in range(0, len(data), db_insert_size):
155
158
  batch = data[i:i + db_insert_size]
156
159
  cursor.executemany(insert_sql, batch)
157
- conn.commit()
158
- except sqlite3.Error as e:
159
- logger.error(f"An sqlite3 error occurred: {e}")
160
- raise RuntimeError("An sqlite3 error occurred") from e
161
- finally:
162
- conn.close()
160
+ conn.commit()
161
+ return
162
+ except sqlite3.OperationalError as e:
163
+ if "database is locked" in str(e).lower():
164
+ retry_count += 1
165
+ if retry_count > max_retries:
166
+ logger.error(f"Database lock conflict retry attempts exhausted ({max_retries}): {e}")
167
+ raise RuntimeError(f"DB lock retry exhausted: {e}") from e
168
+
169
+ logger.warning(
170
+ f"DB lock conflict (retry {retry_count}/{max_retries}), wait {current_delay:.2f}s : {e}"
171
+ )
172
+ time.sleep(current_delay)
173
+ current_delay *= 2
174
+ continue
175
+
176
+ logger.error(f"An sqlite3 error occurred: {e}")
177
+ raise e
178
+ except sqlite3.Error as e:
179
+ logger.error(f"An sqlite3 error occurred: {e}")
180
+ raise e
181
+ except Exception as e:
182
+ logger.error(f"An unknown error occurred: {e}")
183
+ raise e
184
+ finally:
185
+ if conn:
186
+ conn.close()
163
187
 
164
188
 
165
189
  def add_table_index(db_path):
@@ -242,11 +242,15 @@ def _compare_graph_ranks(input_param, args, step=None):
242
242
  def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call):
243
243
  dump_rank_n = input_param.get('npu_path')
244
244
  dump_rank_b = input_param.get('bench_path')
245
- npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK))
246
- bench_ranks = sorted(check_and_return_dir_contents(dump_rank_b, Const.RANK))
245
+ npu_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_n, Const.RANK))
246
+ bench_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_b, Const.RANK))
247
247
  if npu_ranks != bench_ranks:
248
- logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
249
- raise CompareException(CompareException.INVALID_PATH_ERROR)
248
+ intersection_ranks = sort_rank_number_strings(list(set(npu_ranks) & set(bench_ranks)))
249
+ if not intersection_ranks:
250
+ logger.error('The ranks in the two runs are completely different. Unable to match the ranks.')
251
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
252
+ npu_ranks = intersection_ranks
253
+ bench_ranks = intersection_ranks
250
254
  compare_graph_results = []
251
255
  if is_real_data_compare(input_param, npu_ranks, bench_ranks):
252
256
  mp_task_dict = {}
@@ -282,12 +286,16 @@ def _compare_graph_steps(input_param, args):
282
286
  dump_step_n = input_param.get('npu_path')
283
287
  dump_step_b = input_param.get('bench_path')
284
288
 
285
- npu_steps = sorted(check_and_return_dir_contents(dump_step_n, Const.STEP))
286
- bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))
289
+ npu_steps = check_and_return_dir_contents(dump_step_n, Const.STEP)
290
+ bench_steps = check_and_return_dir_contents(dump_step_b, Const.STEP)
287
291
 
288
292
  if npu_steps != bench_steps:
289
- logger.error('The number of steps in the two runs is different. Unable to match the steps.')
290
- raise CompareException(CompareException.INVALID_PATH_ERROR)
293
+ intersection_steps = sort_rank_number_strings(list(set(npu_steps) & set(bench_steps)))
294
+
295
+ if not intersection_steps:
296
+ logger.error('The steps in the two runs are completely different. Unable to match the steps.')
297
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
298
+ npu_steps = intersection_steps
291
299
 
292
300
  args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps])
293
301
 
@@ -355,8 +363,10 @@ def _build_graph_steps(dump_steps_path, args):
355
363
  _build_graph_ranks(dump_ranks_path, args, step)
356
364
 
357
365
 
358
- def _compare_and_export_graph(graph_task_info, input_param, args):
366
+ def _compare_and_export_graph(graph_task_info, input_param, args, step=None):
359
367
  result = _run_graph_compare(graph_task_info, input_param, args)
368
+ if step is not None:
369
+ result.step = get_step_or_rank_int(step)
360
370
  return _export_compare_graph_result(args, result)
361
371
 
362
372
 
@@ -413,7 +423,7 @@ def _compare_graph_ranks_parallel(input_param, args, step=None):
413
423
  _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b),
414
424
  f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time)
415
425
  export_res_task_list.append(pool.apply_async(_compare_and_export_graph,
416
- args=(graph_task_info, input_param, serializable_args),
426
+ args=(graph_task_info, input_param, serializable_args, step),
417
427
  error_callback=err_call))
418
428
  export_res_list = [res.get() for res in export_res_task_list]
419
429
  if any(export_res_list):