mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/README.md CHANGED
@@ -35,17 +35,17 @@ export MSPROBE_LOG_LEVEL={x}
35
35
 
36
36
  ## 环境和依赖
37
37
 
38
- - 硬件环境请参见《[昇腾产品形态说明](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fquickstart%2Fquickstart%2Fquickstart_18_0002.html)》。
39
- - 软件环境请参见《[CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fsoftwareinst%2Finstg%2Finstg_0000.html%3FMode%3DPmIns%26OS%3DUbuntu%26Software%3DcannToolKit)》安装昇腾设备开发或运行环境,即toolkit软件包。
38
+ - 硬件环境请参见《[昇腾产品形态说明](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/quickstart/quickstart/quickstart_18_0002.html)》。
39
+ - 软件环境请参见《[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/softwareinst/instg/instg_0000.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)》安装昇腾设备开发或运行环境,即toolkit软件包。
40
40
 
41
41
  以上环境依赖请根据实际环境选择适配的版本。
42
42
 
43
43
  ## 版本配套说明
44
44
 
45
- - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。
45
+ - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitcode.com/Ascend/pytorch)》。
46
46
  - msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。
47
47
  - msprobe支持MSAdapter 2.1.0。
48
- - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。
48
+ - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://www.hiascend.com/hardware/firmware-drivers/community?product=2&model=28&cann=8.0.RC3.alpha003&driver=1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。
49
49
 
50
50
 
51
51
  ## 🚨 工具限制与注意事项
@@ -84,7 +84,7 @@ msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API
84
84
 
85
85
  精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。
86
86
 
87
- PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md)
87
+ PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)
88
88
 
89
89
  MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md)
90
90
 
@@ -165,7 +165,7 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
165
165
 
166
166
  训练前或精度比对前,对比两个环境下可能影响训练精度的配置差异。
167
167
 
168
- [PyTorch 训练前配置检查](./docs/31.config_check.md)
168
+ [训练前配置检查](./docs/31.config_check.md)
169
169
 
170
170
  训练过程中或结束后,比较两个不同的checkpoint,评估模型相似度。
171
171
 
@@ -24,6 +24,8 @@ class Const:
24
24
  Class for const
25
25
  """
26
26
  TOOL_NAME = "msprobe"
27
+ MD5_INDEX = "md5_index"
28
+ MD5 = "md5"
27
29
 
28
30
  ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
29
31
  SEP = "."
@@ -52,9 +54,9 @@ class Const:
52
54
  SIX_SEGMENT = 6
53
55
  SEVEN_SEGMENT = 7
54
56
 
55
- MAX_DEPTH = 10
57
+ MAX_DEPTH = 400
56
58
  CPU_QUARTER = 4
57
- DUMP_MAX_DEPTH = 50
59
+ DUMP_MAX_DEPTH = 400
58
60
 
59
61
  EXTERN_INPUT_LIST_MAX_LEN = 100
60
62
  MAX_PROCESS_NUM = 128
@@ -72,6 +74,7 @@ class Const:
72
74
  ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
73
75
  SUMMARY = "summary"
74
76
  MD5 = "md5"
77
+ HASH = "hash"
75
78
  VALUE = "value"
76
79
  SUMMARY_MODE = ["statistics", "md5"]
77
80
 
@@ -113,9 +116,13 @@ class Const:
113
116
  RUN_UT = "run_ut"
114
117
  GRAD_PROBE = "grad_probe"
115
118
  STRUCTURE = "structure"
116
- TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE]
119
+ EXCEPTION_DUMP = "exception_dump"
120
+ DUMP_PRECISION_HIGH = "high"
121
+ DUMP_PRECISION_LOW = "low"
122
+ TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE, EXCEPTION_DUMP]
117
123
  DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE]
118
124
  DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
125
+ DUMP_PRECISION_LIST = [DUMP_PRECISION_LOW, DUMP_PRECISION_HIGH]
119
126
  LEVEL_L0 = "L0"
120
127
  LEVEL_L1 = "L1"
121
128
  LEVEL_L2 = "L2"
@@ -237,7 +244,11 @@ class Const:
237
244
  MEAN = 'Mean'
238
245
  NORM = 'Norm'
239
246
  DATA_NAME = 'data_name'
247
+ STATE = 'state'
248
+ REQ_GRAD = 'requires_grad'
249
+ API_ORIGIN_NAME = 'api_origin_name'
240
250
  TENSOR_STAT_INDEX = 'tensor_stat_index'
251
+ SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM]
241
252
 
242
253
  CODE_STACK = 'Code Stack'
243
254
  OP_NAME = 'Op Name'
@@ -260,8 +271,15 @@ class Const:
260
271
 
261
272
  TENSOR_STAT_LEN = 2
262
273
 
274
+ TENSOR_TYPE = "torch.Tensor"
275
+ DTENSOR_TYPE = "torch.distributed.tensor.DTensor"
276
+ FAKE_TENSOR_TYPE = "torch._subclasses.fake_tensor.FakeTensor"
277
+ AC_TENSOR_TYPE = "torch.distributed._functional_collectives.AsyncCollectiveTensor"
278
+
263
279
  SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
264
280
 
281
+ API_ATTR_LIST = ["__name__", "default"]
282
+
265
283
  PT_API_TYPE_FUNCTIONAL = "functional"
266
284
  PT_API_TYPE_TENSOR = "tensor"
267
285
  PT_API_TYPE_TORCH = "torch"
@@ -355,22 +373,22 @@ class Const:
355
373
  }
356
374
 
357
375
  def _fused_adamw_(
358
- self,
359
- grads,
360
- exp_avgs,
361
- exp_avg_sqs,
362
- max_exp_avg_sqs,
363
- state_steps,
364
- *,
365
- lr,
366
- beta1,
367
- beta2,
368
- weight_decay,
369
- eps,
370
- amsgrad,
371
- maximize,
372
- grad_scale=None,
373
- found_inf=None
376
+ self,
377
+ grads,
378
+ exp_avgs,
379
+ exp_avg_sqs,
380
+ max_exp_avg_sqs,
381
+ state_steps,
382
+ *,
383
+ lr,
384
+ beta1,
385
+ beta2,
386
+ weight_decay,
387
+ eps,
388
+ amsgrad,
389
+ maximize,
390
+ grad_scale=None,
391
+ found_inf=None
374
392
  ):
375
393
  pass
376
394
 
@@ -382,6 +400,13 @@ class Const:
382
400
  MATCH_MODE_NAME = "pure name"
383
401
  MATCH_MODE_MAPPING = "mapping"
384
402
  MATCH_MODE_SIMILARITY = "similarity"
403
+ CONFIG_CHECK_PASS = "pass"
404
+ CONFIG_CHECK_WARNING = "warning"
405
+ CONFIG_CHECK_ERROR = "error"
406
+
407
+ MIX_DUMP_NAMES = {'graph', 'pynative'}
408
+
409
+ MEGATRON_MICRO_STEP_NUMBER = 'megatron_micro_step_number'
385
410
 
386
411
 
387
412
  class CompareConst:
@@ -397,10 +422,14 @@ class CompareConst:
397
422
  BENCH_DTYPE = "Bench Dtype"
398
423
  NPU_SHAPE = "NPU Tensor Shape"
399
424
  BENCH_SHAPE = "Bench Tensor Shape"
425
+ NPU_CSV_FILE = "NPU CSV File"
426
+ BENCH_CSV_FILE = "Bench CSV File"
400
427
  NPU_MAX = "NPU max"
401
428
  NPU_MIN = "NPU min"
402
429
  NPU_MEAN = "NPU mean"
403
430
  NPU_NORM = "NPU l2norm"
431
+ NPU_P2POP_PEER = "NPU P2POp peer"
432
+
404
433
  BENCH_MAX = "Bench max"
405
434
  BENCH_MIN = "Bench min"
406
435
  BENCH_MEAN = "Bench mean"
@@ -416,6 +445,9 @@ class CompareConst:
416
445
  MIN_RELATIVE_ERR = "MinRelativeErr"
417
446
  MEAN_RELATIVE_ERR = "MeanRelativeErr"
418
447
  NORM_RELATIVE_ERR = "NormRelativeErr"
448
+ REQ_GRAD_CONSIST = "Requires_grad Consistent"
449
+ NPU_REQ_GRAD = "NPU Requires_grad"
450
+ BENCH_REQ_GRAD = "Bench Requires_grad"
419
451
  ACCURACY = "Accuracy Reached or Not"
420
452
  STACK = "NPU_Stack_Info"
421
453
  DATA_NAME = "Data_name"
@@ -437,7 +469,7 @@ class CompareConst:
437
469
  SUMMARY = "summary"
438
470
  COMPARE_RESULT = "compare_result"
439
471
  COMPARE_MESSAGE = "compare_message"
440
- MAX_EXCEL_LENGTH = 1048576
472
+ MAX_EXCEL_LENGTH = 1048500
441
473
  YES = "Yes"
442
474
  NO = "No"
443
475
  STATISTICS_INDICATOR_NUM = 4
@@ -485,21 +517,21 @@ class CompareConst:
485
517
 
486
518
  ULP_ERR_STATUS = "ulp_err_status"
487
519
 
488
- COMPARE_RESULT_HEADER = [
489
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, EUC_DIST,
490
- MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
491
- NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE
492
- ]
520
+ ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR,
521
+ ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
522
+ SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
523
+ MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
524
+ MD5_COMPARE_INDEX = [RESULT]
493
525
 
494
- SUMMARY_COMPARE_RESULT_HEADER = [
495
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
496
- MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR,
497
- NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE
498
- ]
526
+ BASIC_INFO = [NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_REQ_GRAD, BENCH_REQ_GRAD]
527
+ SUMMARY_INFO = [NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM]
499
528
 
500
- MD5_COMPARE_RESULT_HEADER = [
501
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
502
- ]
529
+ COMPARE_RESULT_HEADER = BASIC_INFO + ALL_COMPARE_INDEX + SUMMARY_INFO + [REQ_GRAD_CONSIST, ACCURACY, ERROR_MESSAGE]
530
+
531
+ SUMMARY_COMPARE_RESULT_HEADER = BASIC_INFO + SUMMARY_COMPARE_INDEX + SUMMARY_INFO + [REQ_GRAD_CONSIST, RESULT,
532
+ ERROR_MESSAGE]
533
+
534
+ MD5_COMPARE_RESULT_HEADER = BASIC_INFO + [NPU_MD5, BENCH_MD5, REQ_GRAD_CONSIST] + MD5_COMPARE_INDEX
503
535
 
504
536
  COMPARE_RESULT_HEADER_STACK = COMPARE_RESULT_HEADER + [STACK]
505
537
 
@@ -513,11 +545,6 @@ class CompareConst:
513
545
  Const.MD5: MD5_COMPARE_RESULT_HEADER
514
546
  }
515
547
 
516
- ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO,
517
- FIVE_THOUSANDTHS_ERR_RATIO]
518
- SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
519
- MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
520
-
521
548
  # dtype match
522
549
 
523
550
  DTYPE_MATCH_GROUPS = [
@@ -554,6 +581,8 @@ class CompareConst:
554
581
  ULP_FLOAT16_THRESHOLD = 1
555
582
 
556
583
  # compare result data
584
+ NO_REAL_DATA = 'No real data'
585
+ API_UNMATCH = 'api unmatched'
557
586
  READ_NONE = 'No data'
558
587
  NONE = 'None'
559
588
  SHAPE_UNMATCH = 'shape unmatched'
@@ -622,6 +651,9 @@ class CompareConst:
622
651
  MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None,
623
652
  MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None
624
653
  }
654
+ MS_GRAPH_CSV = {
655
+ NPU_CSV_FILE: None, BENCH_CSV_FILE: None
656
+ }
625
657
 
626
658
  API_MAPPING_KEYS_TO_COMPARE = [
627
659
  ('ms_args', 'pt_args'),
@@ -641,9 +673,11 @@ class CompareConst:
641
673
 
642
674
  OP_NAME_X = 'op_name_x'
643
675
  MATCH_RESULT_COLUMNS = [
644
- OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x',
676
+ OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'state_x', 'api_origin_name_x',
677
+ 'requires_grad_x', 'data_name_x',
645
678
  CMP_KEY, CMP_SHAPE,
646
- 'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'data_name_y',
679
+ 'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'state_y', 'api_origin_name_y',
680
+ 'requires_grad_y', 'data_name_y'
647
681
  ]
648
682
 
649
683
  INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
@@ -674,6 +708,8 @@ class FileCheckConst:
674
708
  IR_SUFFIX = ".ir"
675
709
  ZIP_SUFFIX = ".zip"
676
710
  SHELL_SUFFIX = ".sh"
711
+ LOG_SUFFIX = ".log"
712
+ DB_SUFFIX = '.db'
677
713
  MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
678
714
  MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
679
715
  MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
@@ -686,6 +722,8 @@ class FileCheckConst:
686
722
  MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
687
723
  MAX_FILE_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
688
724
  COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
725
+ MAX_LOG_SIZE = 10737418240 # 1 * 1024 * 1024 * 1024
726
+ MAX_DB_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
689
727
  DIR = "dir"
690
728
  FILE = "file"
691
729
  DATA_DIR_AUTHORITY = 0o750
@@ -699,7 +737,9 @@ class FileCheckConst:
699
737
  XLSX_SUFFIX: MAX_XLSX_SIZE,
700
738
  YAML_SUFFIX: MAX_YAML_SIZE,
701
739
  IR_SUFFIX: MAX_IR_SIZE,
702
- ZIP_SUFFIX: MAX_ZIP_SIZE
740
+ ZIP_SUFFIX: MAX_ZIP_SIZE,
741
+ LOG_SUFFIX: MAX_LOG_SIZE,
742
+ DB_SUFFIX: MAX_DB_SIZE
703
743
  }
704
744
  CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
705
745
 
@@ -758,6 +798,11 @@ class MonitorConst:
758
798
  DEFAULT_STEP_INTERVAL = 1
759
799
 
760
800
  OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"]
801
+ OP_MONVIS_SUPPORTED = [
802
+ "norm", "min", "max", "zeros", "nans", "mean",
803
+ "entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian",
804
+ "proxy", "token_similarity"
805
+ ]
761
806
  MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
762
807
  DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
763
808
  DATABASE = "database"
@@ -770,6 +815,8 @@ class MonitorConst:
770
815
  )
771
816
  DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
772
817
  RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan']
818
+ L2_HOOKS = ["linear_hook", "attention_hook"]
819
+ SA_ORDERS = ["s,b,h,d", "b,s,h,d"]
773
820
 
774
821
  SLICE_SIZE = 20480
775
822
  # used for name
@@ -781,6 +828,7 @@ class MonitorConst:
781
828
  ACTV_OUT = "output"
782
829
  ACTVGRAD_IN = "input_grad"
783
830
  ACTVGRAD_OUT = "output_grad"
831
+ FSDP_FLAT_SEP = "_fsdp_wrapped_module."
784
832
  # used for tasks
785
833
  ACTV = "actv"
786
834
  ACTVGRAD = "actv_grad"
@@ -820,3 +868,12 @@ class MonitorConst:
820
868
  TRAIN_STAGE[key] = BACKWARD_STAGE
821
869
  for key in OPTIMIZER_KEY:
822
870
  TRAIN_STAGE[key] = OPTIMIZER_STAGE
871
+
872
+ # csv2db
873
+ DEFAULT_INT_VALUE = 0
874
+ MAX_PROCESS_NUM = 128
875
+ CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"
876
+ BATCH_SIZE = 10000
877
+ MAX_PARTITION = 10_000_000
878
+ MIN_PARTITION = 10
879
+
@@ -0,0 +1,256 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+ import sqlite3
17
+ from typing import List, Tuple, Dict, Any
18
+ from functools import wraps
19
+
20
+ from msprobe.pytorch.common.log import logger
21
+ from msprobe.core.common.file_utils import check_path_before_create, change_mode
22
+ from msprobe.core.common.const import FileCheckConst
23
+
24
+ SAFE_SQL_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
25
+
26
+
27
+ def check_identifier_safety(name):
28
+ """验证标识符是否安全(防止SQL注入)"""
29
+ if not isinstance(name, str) or SAFE_SQL_PATTERN.match(name) is None:
30
+ raise ValueError(f"Invalid SQL identifier: {name}, potential SQL injection risk!")
31
+
32
+
33
+ def _db_operation(func):
34
+ """数据库操作装饰器,自动管理连接"""
35
+ @wraps(func)
36
+ def wrapper(self, *args, **kwargs):
37
+ conn, curs = None, None
38
+ try:
39
+ conn, curs = self._get_connection()
40
+ result = func(self, conn, curs, *args, **kwargs)
41
+ return result # 显式返回正常结果
42
+
43
+ except sqlite3.Error as err:
44
+ logger.error(f"Database operation failed: {err}")
45
+ if conn:
46
+ conn.rollback()
47
+ return None # 显式返回错误情况下的None
48
+
49
+ finally:
50
+ self._release_connection(conn, curs)
51
+ return wrapper
52
+
53
+
54
+ class DBManager:
55
+ """
56
+ 数据库管理类,封装常用数据库操作
57
+ """
58
+
59
+ DEFAULT_FETCH_SIZE = 10000
60
+ DEFAULT_INSERT_SIZE = 10000
61
+ MAX_ROW_COUNT = 100000000
62
+
63
+ def __init__(self, db_path: str):
64
+ """
65
+ 初始化DBManager
66
+ :param db_path: 数据库文件路径
67
+ :param table_config: 表配置对象
68
+ """
69
+ self.db_path = db_path
70
+
71
+ @staticmethod
72
+ def _get_where_sql(where_list):
73
+ if not where_list:
74
+ return "", tuple()
75
+
76
+ where_clauses = []
77
+ where_values = []
78
+ if where_list:
79
+ for col, val in where_list.items():
80
+ check_identifier_safety(col)
81
+ where_clauses.append(f"{col} = ?")
82
+ where_values.append(val)
83
+ if where_clauses:
84
+ where_sql = " WHERE " + " AND ".join(where_clauses)
85
+ return where_sql, tuple(where_values)
86
+
87
+ @_db_operation
88
+ def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
89
+ table_name: str, data: List[Tuple], key_list: List[str] = None) -> int:
90
+ """
91
+ 批量插入数据
92
+ :param table_name: 表名
93
+ :param data: 要插入的数据列表
94
+ :param batch_size: 每批插入的大小
95
+ :return: 插入的行数
96
+ """
97
+ check_identifier_safety(table_name)
98
+
99
+ if not data:
100
+ return 0
101
+ columns = len(data[0])
102
+ if key_list:
103
+ if not isinstance(key_list, list):
104
+ raise TypeError(
105
+ f"key_list must be a list, got {type(key_list)}"
106
+ )
107
+ if columns != len(key_list):
108
+ raise ValueError(
109
+ f"When inserting into table {table_name}, the length of key list ({key_list})"
110
+ f"does not match the data({columns}).")
111
+ for key in key_list:
112
+ check_identifier_safety(key)
113
+
114
+ batch_size = self.DEFAULT_INSERT_SIZE
115
+ placeholders = ", ".join(["?"] * columns)
116
+ if key_list:
117
+ keys = ", ".join(key_list)
118
+ sql = f"INSERT OR IGNORE INTO {table_name} ({keys}) VALUES ({placeholders})"
119
+ else:
120
+ sql = f"INSERT OR IGNORE INTO {table_name} VALUES ({placeholders})"
121
+
122
+ inserted_rows = 0
123
+ for i in range(0, len(data), batch_size):
124
+ batch = data[i:i + batch_size]
125
+ curs.executemany(sql, batch)
126
+ inserted_rows += curs.rowcount
127
+
128
+ conn.commit()
129
+ return inserted_rows
130
+
131
+ @_db_operation
132
+ def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
133
+ table_name: str,
134
+ columns: List[str] = None,
135
+ where: dict = None) -> List[Dict]:
136
+ """
137
+ 查询数据
138
+ :param table_name: 表名
139
+ :param columns: 要查询的列
140
+ :param where: WHERE条件
141
+ :return: 查询结果列表(字典形式)
142
+ """
143
+ check_identifier_safety(table_name)
144
+
145
+ if not columns:
146
+ raise ValueError("columns parameter cannot be empty, specify columns to select (e.g. ['id', 'name'])")
147
+ if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns):
148
+ raise TypeError("columns must be a list of strings (e.g. ['id', 'name'])")
149
+
150
+ for col in columns:
151
+ check_identifier_safety(col)
152
+
153
+ cols = ", ".join(columns)
154
+ sql = f"SELECT {cols} FROM {table_name}"
155
+
156
+ where_sql, where_parems = self._get_where_sql(where)
157
+ curs.execute(sql + where_sql, where_parems)
158
+
159
+ return [dict(row) for row in curs.fetchall()]
160
+
161
+ @_db_operation
162
+ def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
163
+ table_name: str, updates: Dict[str, Any],
164
+ where: dict = None) -> int:
165
+ """
166
+ 更新数据
167
+ :param table_name: 表名
168
+ :param updates: 要更新的字段和值
169
+ :param where: WHERE条件
170
+ :param where_params: WHERE条件参数
171
+ :return: 影响的行数
172
+ """
173
+ check_identifier_safety(table_name)
174
+ if not updates:
175
+ raise ValueError("columns parameter cannot be empty, specify it to update (e.g. {'name': 'xxx'}")
176
+ if not isinstance(updates, dict):
177
+ raise TypeError(f"updates must be a dictionary, got: {type(updates)}")
178
+ for key in updates.keys():
179
+ check_identifier_safety(key)
180
+
181
+ set_clause = ", ".join([f"{k} = ?" for k in updates.keys()])
182
+ sql = f"UPDATE {table_name} SET {set_clause}"
183
+
184
+ params = tuple(updates.values())
185
+
186
+ where_sql, where_parems = self._get_where_sql(where)
187
+
188
+ curs.execute(sql + where_sql, params + where_parems)
189
+ conn.commit()
190
+ return curs.rowcount
191
+
192
+ @_db_operation
193
+ def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
194
+ sql: str, params: Tuple = None) -> List[Dict]:
195
+ """
196
+ 执行自定义SQL查询
197
+ :param sql: SQL语句
198
+ :param params: 参数
199
+ :return: 查询结果
200
+ """
201
+ curs.execute(sql, params or ())
202
+ if sql.strip().upper().startswith("SELECT"):
203
+ return [dict(row) for row in curs.fetchall()]
204
+ conn.commit()
205
+ return []
206
+
207
+ def table_exists(self, table_name: str) -> bool:
208
+ """
209
+ :param table_name: 表名
210
+ :return: 查询结果
211
+ """
212
+ result = self.select_data(
213
+ table_name="sqlite_master",
214
+ columns=["name"],
215
+ where={"type": "table", "name": table_name}
216
+ )
217
+ return len(result) > 0
218
+
219
+ @_db_operation
220
+ def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
221
+ sql_commands: List[str]) -> List[List[Dict]]:
222
+ """
223
+ 批量执行多个SQL语句
224
+ :param sql_commands: [sql1, sql2, ...]
225
+ :return: 每个SELECT语句的结果列表
226
+ """
227
+ results = []
228
+ for sql in sql_commands:
229
+ curs.execute(sql)
230
+ if sql.strip().upper().startswith("SELECT"):
231
+ results.append([dict(row) for row in curs.fetchall()])
232
+ conn.commit()
233
+ return results
234
+
235
+ def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
236
+ """获取数据库连接和游标"""
237
+ check_path_before_create(self.db_path)
238
+ try:
239
+ conn = sqlite3.connect(self.db_path)
240
+ conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果
241
+ curs = conn.cursor()
242
+ return conn, curs
243
+ except sqlite3.Error as err:
244
+ logger.error(f"Database connection failed: {err}")
245
+ raise
246
+
247
+ def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None:
248
+ """释放数据库连接"""
249
+ try:
250
+ if curs is not None:
251
+ curs.close()
252
+ if conn is not None:
253
+ conn.close()
254
+ except sqlite3.Error as err:
255
+ logger.error(f"Failed to release database connection: {err}")
256
+ change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -33,7 +33,7 @@ import pandas as pd
33
33
  from msprobe.core.common.decorator import recursion_depth_decorator
34
34
  from msprobe.core.common.log import logger
35
35
  from msprobe.core.common.exceptions import FileCheckException
36
- from msprobe.core.common.const import FileCheckConst, CompareConst
36
+ from msprobe.core.common.const import FileCheckConst, CompareConst, Const
37
37
  from msprobe.core.common.global_lock import global_lock, is_main_process
38
38
 
39
39
  proc_lock = multiprocessing.Lock()
@@ -172,7 +172,7 @@ def check_path_exists(path):
172
172
  if not os.path.exists(path):
173
173
  logger.error('The file path %s does not exist.' % path)
174
174
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
175
-
175
+
176
176
 
177
177
  def check_path_not_exists(path):
178
178
  if os.path.exists(path):
@@ -259,8 +259,8 @@ def check_path_type(file_path, file_type):
259
259
def check_others_writable(directory):
    """
    Report whether *directory* is writable by its group or by other users.

    Despite the name, the group write bit is checked as well as the
    "others" write bit. The owner's write bit is ignored.

    :param directory: path passed to ``os.stat`` (symlinks are followed)
    :return: True if ``S_IWGRP`` or ``S_IWOTH`` is set, otherwise False
    """
    non_owner_write_bits = stat.S_IWGRP | stat.S_IWOTH
    mode = os.stat(directory).st_mode
    return bool(mode & non_owner_write_bits)
266
266
 
@@ -319,7 +319,7 @@ def check_dirpath_before_read(path):
319
319
  check_path_owner_consistent(dirpath)
320
320
  except FileCheckException:
321
321
  logger.warning(f"The directory {dirpath} is not yours.")
322
-
322
+
323
323
 
324
324
  def check_file_or_directory_path(path, isdir=False):
325
325
  """
@@ -422,6 +422,26 @@ def load_json(json_path):
422
422
  return data
423
423
 
424
424
 
425
def load_construct_json(json_path):
    """
    Load a construct JSON file and split out per-entry micro-step information.

    If the loaded JSON has no ``Const.MEGATRON_MICRO_STEP_NUMBER`` key, it is
    returned unchanged together with an empty dict.

    Otherwise that key is moved into the micro-step dict, and every remaining
    entry is split:
      * a list value must have exactly two elements ``[construct, micro_step]``;
        the first goes to ``construct_dict`` and the second to ``micro_step_dict``;
      * any other value goes to ``construct_dict`` with micro step ``0``.

    :param json_path: path of the construct JSON file
    :return: ``(construct_dict, micro_step_dict)``
    :raises RuntimeError: when a list-valued entry does not have exactly two elements
    """
    construct_dict_o = load_json(json_path)
    if Const.MEGATRON_MICRO_STEP_NUMBER not in construct_dict_o:
        return construct_dict_o, {}

    micro_step_dict = {
        Const.MEGATRON_MICRO_STEP_NUMBER: construct_dict_o.pop(Const.MEGATRON_MICRO_STEP_NUMBER)
    }
    construct_dict = {}
    for key, value in construct_dict_o.items():
        if not isinstance(value, list):
            construct_dict[key] = value
            micro_step_dict[key] = 0
            continue
        if len(value) != 2:
            logger.error(f'Parse construct json file "{os.path.basename(json_path)}" failed.')
            # Carry the offending entry in the exception so callers are not left with an empty message.
            raise RuntimeError(
                f'Parse construct json file "{os.path.basename(json_path)}" failed: entry "{key}" '
                f'must be a 2-element list [construct, micro_step], got {len(value)} element(s).'
            )
        construct_dict[key] = value[0]
        micro_step_dict[key] = value[1]
    return construct_dict, micro_step_dict
443
+
444
+
425
445
  def save_json(json_path, data, indent=None, mode="w"):
426
446
  check_path_before_create(json_path)
427
447
  json_path = os.path.realpath(json_path)
@@ -520,6 +540,9 @@ def move_directory(src_path, dst_path):
520
540
  check_file_or_directory_path(src_path, isdir=True)
521
541
  check_path_before_create(dst_path)
522
542
  try:
543
+ if os.path.exists(dst_path):
544
+ logger.warning(f"The destination directory {dst_path} already exists, it will be removed.")
545
+ shutil.rmtree(dst_path)
523
546
  shutil.move(src_path, dst_path)
524
547
  except Exception as e:
525
548
  logger.error(f"move directory {src_path} to {dst_path} failed")
@@ -89,6 +89,13 @@ class BaseLogger:
89
89
  self.error(msg)
90
90
  raise exception
91
91
 
92
def warning_log_with_exp(self, msg, exception):
    """
    Log *msg* at warning level, then raise *exception*.

    :param msg: message passed to ``self.warning``
    :param exception: exception instance (or class) to raise after logging
    :raises: always raises *exception*
    """
    emit_warning = self.warning
    emit_warning(msg)
    raise exception
98
+
92
99
  def _print_log(self, level, msg, end='\n'):
93
100
  current_rank = self.get_rank()
94
101
  current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())