mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ from functools import wraps
18
+
19
+ from msprobe.core.common.const import Const
20
+ from msprobe.core.common.exceptions import MsprobeException
21
+ from msprobe.core.common.log import logger
22
+
23
# Current recursion depth of every decorated utility function, keyed by id(func).
recursion_depth = defaultdict(int)


def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
    """Limit how deeply a function may recurse.

    Each call increments a per-function depth counter; when it exceeds ``max_depth`` the
    problem is reported through ``logger.error_log_with_exp`` together with a
    ``MsprobeException(RECURSION_LIMIT_ERROR)`` (that helper is expected to raise it —
    TODO confirm).

    The counter is decremented in ``finally`` for every call, including the call that
    trips the limit. Otherwise a single failure would leave the counter permanently
    inflated, and the function would become unusable after enough failures.

    :param func_info: human-readable function identifier used in the error message
    :param max_depth: maximum allowed nesting depth
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            try:
                if recursion_depth[func_id] > max_depth:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
@@ -21,19 +21,21 @@ class CodedException(Exception):
21
21
 
22
22
  def __str__(self):
23
23
  return self.error_info
24
-
25
-
24
+
25
+
26
26
  class MsprobeException(CodedException):
27
27
  INVALID_PARAM_ERROR = 0
28
28
  OVERFLOW_NUMS_ERROR = 1
29
29
  RECURSION_LIMIT_ERROR = 2
30
30
  INTERFACE_USAGE_ERROR = 3
31
+ UNSUPPORTED_TYPE_ERROR = 4
31
32
 
32
33
  err_strs = {
33
34
  INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
34
35
  OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
35
36
  RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
36
- INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
37
+ INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: ",
38
+ UNSUPPORTED_TYPE_ERROR: "[msprobe] Unsupported type: "
37
39
  }
38
40
 
39
41
 
@@ -12,23 +12,31 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
15
+ import atexit
16
16
  import csv
17
17
  import fcntl
18
+ import io
18
19
  import os
20
+ import pickle
21
+ from multiprocessing import shared_memory
19
22
  import stat
20
23
  import json
21
24
  import re
22
25
  import shutil
23
- from datetime import datetime, timezone
24
- from dateutil import parser
26
+ import sys
27
+ import zipfile
28
+ import multiprocessing
25
29
  import yaml
26
30
  import numpy as np
27
31
  import pandas as pd
28
32
 
33
+ from msprobe.core.common.decorator import recursion_depth_decorator
29
34
  from msprobe.core.common.log import logger
30
35
  from msprobe.core.common.exceptions import FileCheckException
31
36
  from msprobe.core.common.const import FileCheckConst
37
+ from msprobe.core.common.global_lock import global_lock, is_main_process
38
+
39
# Module-level lock serializing ZIP archive access (acquired by add_file_to_zip and extract_zip).
+ proc_lock = multiprocessing.Lock()
32
40
 
33
41
 
34
42
  class FileChecker:
@@ -164,6 +172,12 @@ def check_path_exists(path):
164
172
  if not os.path.exists(path):
165
173
  logger.error('The file path %s does not exist.' % path)
166
174
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
175
+
176
+
177
def check_path_not_exists(path):
    """Ensure that ``path`` does not exist yet.

    :param path: file or directory path to check
    :raises FileCheckException: ILLEGAL_PATH_ERROR if the path already exists (an error is logged first)
    """
    if not os.path.exists(path):
        return
    logger.error('The file path %s already exist.' % path)
    raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
167
181
 
168
182
 
169
183
  def check_path_readability(path):
@@ -266,6 +280,7 @@ def make_dir(dir_path):
266
280
  file_check.common_check()
267
281
 
268
282
 
283
+ @recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
269
284
  def create_directory(dir_path):
270
285
  """
271
286
  Function Description:
@@ -297,12 +312,13 @@ def check_path_before_create(path):
297
312
  def check_dirpath_before_read(path):
  # Warn (without raising) if the parent directory of `path` is writable by others or is
  # not owned by the current user. The new version guards the checks with dedup_log,
  # which returns False for a directory it has already recorded, so each directory is
  # reported only once.
298
313
  path = os.path.realpath(path)
299
314
  dirpath = os.path.dirname(path)
300
- if check_others_writable(dirpath):
301
- logger.warning(f"The directory is writable by others: {dirpath}.")
302
- try:
303
- check_path_owner_consistent(dirpath)
304
- except FileCheckException:
305
- logger.warning(f"The directory {dirpath} is not yours.")
315
+ if dedup_log('check_dirpath_before_read', dirpath):
316
+ if check_others_writable(dirpath):
317
+ logger.warning(f"The directory is writable by others: {dirpath}.")
318
+ try:
319
+ check_path_owner_consistent(dirpath)
320
+ except FileCheckException:
321
+ logger.warning(f"The directory {dirpath} is not yours.")
306
322
 
307
323
 
308
324
  def check_file_or_directory_path(path, isdir=False):
@@ -332,6 +348,23 @@ def change_mode(path, mode):
332
348
  'Failed to change {} authority. {}'.format(path, str(ex))) from ex
333
349
 
334
350
 
351
@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod')
def recursive_chmod(path):
    """
    Change the permissions of everything below a directory.

    Files are set to FileCheckConst.DATA_FILE_AUTHORITY (640) and directories to
    FileCheckConst.DATA_DIR_AUTHORITY (750). The permissions of ``path`` itself are
    left unchanged.

    os.walk already descends into every subdirectory, so paths must be joined with the
    walk's current ``root``. The previous version joined with the top-level ``path`` and
    also recursed manually, which built nonexistent paths for nested entries.
    The walk is top-down, so a directory is chmod-ed before it is entered.

    :param path: directory whose contents should have their permissions changed
    """
    for root, dirs, files in os.walk(path):
        for file_name in files:
            change_mode(os.path.join(root, file_name), FileCheckConst.DATA_FILE_AUTHORITY)
        for dir_name in dirs:
            change_mode(os.path.join(root, dir_name), FileCheckConst.DATA_DIR_AUTHORITY)
366
+
367
+
335
368
  def path_len_exceeds_limit(file_path):
336
369
  return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
337
370
  len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
@@ -446,6 +479,15 @@ def save_excel(path, data):
446
479
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
447
480
 
448
481
 
482
def move_directory(src_path, dst_path):
    """
    Move the directory ``src_path`` to ``dst_path``, then chmod ``dst_path``.

    ``src_path`` must pass the directory checks and ``dst_path`` the pre-creation checks.
    After a successful move, ``dst_path`` is set to FileCheckConst.DATA_DIR_AUTHORITY.

    :param src_path: existing directory to move
    :param dst_path: destination path
    :raises RuntimeError: if shutil.move fails
    """
    check_file_or_directory_path(src_path, isdir=True)
    check_path_before_create(dst_path)
    failure_msg = f"move directory {src_path} to {dst_path} failed"
    try:
        shutil.move(src_path, dst_path)
    except Exception as err:
        logger.error(failure_msg)
        raise RuntimeError(failure_msg) from err
    change_mode(dst_path, FileCheckConst.DATA_DIR_AUTHORITY)
449
491
 
450
492
 
451
493
  def move_file(src_path, dst_path):
@@ -511,7 +553,7 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
511
553
  if not isinstance(value, str):
512
554
  return True
513
555
  try:
514
- # -1.00 or +1.00 should be consdiered as digit numbers
556
+ # -1.00 or +1.00 should be considered as digit numbers
515
557
  float(value)
516
558
  except ValueError:
517
559
  # otherwise, they will be considered as formular injections
@@ -557,7 +599,7 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
557
599
  if not isinstance(value, str):
558
600
  return True
559
601
  try:
560
- # -1.00 or +1.00 should be consdiered as digit numbers
602
+ # -1.00 or +1.00 should be considered as digit numbers
561
603
  float(value)
562
604
  except ValueError:
563
605
  # otherwise, they will be considered as formular injections
@@ -588,8 +630,11 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
588
630
  def remove_path(path):
589
631
  if not os.path.exists(path):
590
632
  return
633
+ if os.path.islink(path):
634
+ logger.error(f"Failed to delete {path}, it is a symbolic link.")
635
+ raise RuntimeError("Delete file or directory failed.")
591
636
  try:
592
- if os.path.islink(path) or os.path.isfile(path):
637
+ if os.path.isfile(path):
593
638
  os.remove(path)
594
639
  else:
595
640
  shutil.rmtree(path)
@@ -598,7 +643,7 @@ def remove_path(path):
598
643
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err
599
644
  except Exception as e:
600
645
  logger.error("Failed to delete {}. Please check.".format(path))
601
- raise RuntimeError(f"Delete {path} failed.") from e
646
+ raise RuntimeError("Delete file or directory failed.") from e
602
647
 
603
648
 
604
649
  def get_json_contents(file_path):
@@ -632,42 +677,231 @@ def os_walk_for_files(path, depth):
632
677
  return res
633
678
 
634
679
 
635
- def check_crt_valid(pem_path):
680
def read_xlsx(file_path, sheet_name=None):
    """
    Load an Excel file into a DataFrame, keeping empty cells as empty strings.

    :param file_path: path of the .xlsx file (validated before reading)
    :param sheet_name: sheet to read; when falsy, pandas' default sheet is used
    :return: the DataFrame returned by ``pd.read_excel``
    :raises RuntimeError: if pandas fails to read the file
    """
    check_file_or_directory_path(file_path)
    read_kwargs = {"keep_default_na": False}
    if sheet_name:
        read_kwargs["sheet_name"] = sheet_name
    try:
        result_df = pd.read_excel(file_path, **read_kwargs)
    except Exception as e:
        logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
        raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
    return result_df
691
+
692
+
693
def create_file_with_list(result_list, filepath):
    """
    Write each string in ``result_list`` as its own line to ``filepath``.

    The file is opened in 'w' mode with UTF-8 encoding and written while holding an
    exclusive flock. Afterwards its mode is set to FileCheckConst.DATA_FILE_AUTHORITY.

    :param result_list: iterable of strings, one per output line
    :param filepath: destination path (checked before creation)
    :raises RuntimeError: if opening or writing the file fails
    """
    check_path_before_create(filepath)
    real_path = os.path.realpath(filepath)
    try:
        with FileOpen(real_path, 'w', encoding='utf-8') as out_file:
            fcntl.flock(out_file, fcntl.LOCK_EX)
            for line in result_list:
                out_file.write(line + '\n')
            fcntl.flock(out_file, fcntl.LOCK_UN)
    except Exception as e:
        short_name = os.path.basename(real_path)
        logger.error(f'Save list to file "{short_name}" failed.')
        raise RuntimeError(f"Save list to file {short_name} failed.") from e
    change_mode(real_path, FileCheckConst.DATA_FILE_AUTHORITY)
706
+
707
+
708
def create_file_with_content(data, filepath):
    """
    Write the text ``data`` to ``filepath`` under an exclusive flock.

    The file is opened in 'w' mode with UTF-8 encoding. Afterwards its mode is set to
    FileCheckConst.DATA_FILE_AUTHORITY.

    :param data: text to write
    :param filepath: destination path (checked before creation)
    :raises RuntimeError: if opening or writing the file fails
    """
    check_path_before_create(filepath)
    real_path = os.path.realpath(filepath)
    try:
        with FileOpen(real_path, 'w', encoding='utf-8') as handle:
            fcntl.flock(handle, fcntl.LOCK_EX)
            handle.write(data)
            fcntl.flock(handle, fcntl.LOCK_UN)
    except Exception as e:
        short_name = os.path.basename(real_path)
        logger.error(f'Save content to file "{short_name}" failed.')
        raise RuntimeError(f"Save content to file {short_name} failed.") from e
    change_mode(real_path, FileCheckConst.DATA_FILE_AUTHORITY)
720
+
721
+
722
def add_file_to_zip(zip_file_path, file_path, arc_path=None):
    """
    Add a file to a ZIP archive, creating the archive if it does not exist.

    :param zip_file_path: Path to the ZIP archive
    :param file_path: Path to the file to add
    :param arc_path: Optional path inside the ZIP archive where the file should be added
    :raises RuntimeError: if the archive would exceed FileCheckConst.MAX_ZIP_SIZE,
        or if writing to the archive fails
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    check_file_size(file_path, FileCheckConst.MAX_FILE_IN_ZIP_SIZE)
    zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
    if zip_size + os.path.getsize(file_path) > FileCheckConst.MAX_ZIP_SIZE:
        raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
    check_path_before_create(zip_file_path)
    try:
        # `with` releases proc_lock on every exit path. Unlike acquire() inside try with
        # release() in finally, it never releases a lock that was not acquired.
        with proc_lock:
            with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
                zip_file.write(file_path, arc_path)
    except Exception as e:
        logger.error(f'add file to zip "{os.path.basename(zip_file_path)}" failed.')
        raise RuntimeError(f"add file to zip {os.path.basename(zip_file_path)} failed.") from e
    change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
641
746
 
642
- Parameters:
643
- pem_path (str): The file path of the SSL certificate.
644
747
 
645
- Raises:
646
- RuntimeError: If the SSL certificate is invalid or expired.
748
def create_file_in_zip(zip_file_path, file_name, content):
    """
    Create a file with content inside a ZIP archive, creating the archive if needed.

    Concurrent writers are serialized with an exclusive flock on the archive file.

    :param zip_file_path: Path to the ZIP archive
    :param file_name: Name of the file to create
    :param content: Content to write to the file (a str is stored UTF-8 encoded by zipfile)
    :raises RuntimeError: if the archive would exceed FileCheckConst.MAX_ZIP_SIZE,
        or if writing to the archive fails
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    check_path_before_create(zip_file_path)
    zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
    # Measure the bytes that will actually be stored. zipfile encodes str as UTF-8, while
    # sys.getsizeof reports the in-memory object size (header plus 1/2/4 bytes per
    # character). That overestimates ASCII text and underestimates e.g. CJK text.
    if isinstance(content, str):
        content_size = len(content.encode('utf-8', errors='surrogatepass'))
    elif isinstance(content, (bytes, bytearray)):
        content_size = len(content)
    else:
        content_size = sys.getsizeof(content)
    if zip_size + content_size > FileCheckConst.MAX_ZIP_SIZE:
        raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
    try:
        # 'a+' opens the archive without truncating it (creating it if missing) and
        # provides a descriptor that flock can lock.
        with open(zip_file_path, 'a+') as f:
            # Exclusive lock; blocks until any other holder releases it.
            fcntl.flock(f, fcntl.LOCK_EX)
            with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
                zip_info = zipfile.ZipInfo(file_name)
                zip_info.compress_type = zipfile.ZIP_DEFLATED
                zip_file.writestr(zip_info, content)
            fcntl.flock(f, fcntl.LOCK_UN)
    except Exception as e:
        logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
        raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
    change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
660
774
 
661
- now_utc = datetime.now(tz=timezone.utc)
662
- if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
663
- raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
664
775
 
776
def extract_zip(zip_file_path, extract_dir):
    """
    Extract the contents of a ZIP archive to a specified directory.

    The member count and declared member sizes are validated before anything is
    written. Validation and extraction use the same opened archive while holding
    ``proc_lock``, so the entries that were checked are exactly the entries extracted.

    :param zip_file_path: Path to the ZIP archive
    :param extract_dir: Directory to extract the contents to
    :raises RuntimeError: if the archive cannot be opened or fails validation
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    archive_name = os.path.basename(zip_file_path)
    with proc_lock:
        zip_file = None
        try:
            zip_file = zipfile.ZipFile(zip_file_path, 'r')
            members = zip_file.infolist()
            # NOTE(review): the member *count* is compared with MAX_FILE_IN_ZIP_SIZE,
            # which is used below as a per-file *size* limit; confirm whether a
            # dedicated file-count limit was intended.
            if len(members) > FileCheckConst.MAX_FILE_IN_ZIP_SIZE:
                raise ValueError(f"Too many files in {archive_name}")
            total_size = 0
            for file_info in members:
                if file_info.file_size > FileCheckConst.MAX_FILE_IN_ZIP_SIZE:
                    raise ValueError(f"File {file_info.filename} is too large to extract")
                total_size += file_info.file_size
                if total_size > FileCheckConst.MAX_ZIP_SIZE:
                    raise ValueError(f"Total extracted size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
        except Exception as e:
            if zip_file is not None:
                zip_file.close()
            logger.error(f'Extract zip file "{archive_name}" failed.')
            raise RuntimeError(f"Extract zip file {archive_name} failed.") from e
        # Errors raised by the extraction itself are not wrapped in RuntimeError.
        with zip_file:
            zip_file.extractall(extract_dir)
804
+
805
+
806
def split_zip_file_path(zip_file_path):
    """
    Split the resolved path of a ZIP archive into (directory, file name).

    :param zip_file_path: path ending with the ZIP suffix (checked)
    :return: tuple ``(dirname, basename)`` of ``os.path.realpath(zip_file_path)``
    """
    check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
    resolved = os.path.realpath(zip_file_path)
    # os.path.split returns exactly (os.path.dirname(p), os.path.basename(p)).
    return os.path.split(resolved)
810
+
811
+
812
def dedup_log(func_name, filter_name):
    """
    Report whether ``filter_name`` is seen for the first time under ``func_name``.

    Seen names are kept as a set per ``func_name`` inside a SharedDict. The first call
    for a given name records it and returns True; later calls return False.

    :param func_name: namespace key, typically the calling function's name
    :param filter_name: value to deduplicate (e.g. a directory path)
    :return: True on first sighting, otherwise False
    """
    with SharedDict() as shared_dict:
        seen_names = shared_dict.get(func_name, set())
        first_time = filter_name not in seen_names
        if first_time:
            seen_names.add(filter_name)
            shared_dict[func_name] = seen_names
        return first_time
820
+
821
+
822
class SharedDict:
    """
    Small dict persisted in a named shared-memory segment, used as a context manager.

    Entering unpickles the dict from the segment; the segment (1 MiB) is created on first
    use. Exiting pickles the dict back if it was modified, then closes this process's
    mapping. The segment is named after the main process's pid, while non-main processes
    use their parent's pid and therefore attach to the parent's segment.

    If the pickled dict no longer fits in the segment, the update is skipped and logged
    at debug level. Previously the buffer write raised and the mapping was never closed.
    """

    # Size in bytes requested when the segment is first created.
    _SHM_SIZE = 1024 * 1024

    def __init__(self):
        self._changed = False
        self._dict = None
        self._shm = None

    def __enter__(self):
        self._load_shared_memory()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            try:
                if self._changed:
                    self._write_back()
            finally:
                # Always release this process's mapping, even if writing back failed.
                self._shm.close()
        except FileNotFoundError:
            name = self.get_shared_memory_name()
            logger.debug(f'close shared memory {name} failed, shared memory has already been destroyed.')

    def __setitem__(self, key, value):
        self._dict[key] = value
        self._changed = True

    def __contains__(self, item):
        return item in self._dict

    @classmethod
    def destroy_shared_memory(cls):
        """Close and unlink the segment; only the main process does this."""
        if is_main_process():
            name = cls.get_shared_memory_name()
            try:
                shm = shared_memory.SharedMemory(create=False, name=name)
                shm.close()
                shm.unlink()
                logger.debug(f'destroy shared memory, name: {name}')
            except FileNotFoundError:
                logger.debug(f'destroy shared memory {name} failed, shared memory has already been destroyed.')

    @classmethod
    def get_shared_memory_name(cls):
        """Segment name derived from the main process's pid (the parent's pid in children)."""
        if is_main_process():
            return f'shared_memory_{os.getpid()}'
        return f'shared_memory_{os.getppid()}'

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def _write_back(self):
        """Pickle the dict into the segment under the global lock, unless it does not fit."""
        data = pickle.dumps(self._dict)
        capacity = self._shm.size
        if len(data) > capacity:
            # Writing would raise (memoryview slice assignment size mismatch). The cached
            # entries only serve best-effort purposes such as log de-duplication, so the
            # update is dropped instead of failing the caller.
            logger.debug(f'shared dict needs {len(data)} bytes but shared memory '
                         f'{self.get_shared_memory_name()} holds only {capacity} bytes, skip writing it back.')
            return
        global_lock.acquire()
        try:
            self._shm.buf[0:len(data)] = bytearray(data)
        finally:
            global_lock.release()

    def _load_shared_memory(self):
        """Attach to (or create) the segment, then load its dict."""
        name = self.get_shared_memory_name()
        try:
            self._shm = shared_memory.SharedMemory(create=False, name=name)
        except FileNotFoundError:
            try:
                self._shm = shared_memory.SharedMemory(create=True, name=name, size=self._SHM_SIZE)
                data = pickle.dumps({})
                self._shm.buf[0:len(data)] = bytearray(data)
                logger.debug(f'create shared memory, name: {name}')
            except FileExistsError:
                # Another process created it between our attach attempt and create.
                self._shm = shared_memory.SharedMemory(create=False, name=name)
        self._safe_load()

    def _safe_load(self):
        """Unpickle the dict with the restricted SafeUnpickler; reset to {} if unreadable."""
        with io.BytesIO(self._shm.buf[:]) as buff:
            try:
                self._dict = SafeUnpickler(buff).load()
            except Exception as e:
                logger.debug(f'shared dict is unreadable, reason: {e}, create new dict.')
                self._dict = {}
                # Mark dirty so the fresh empty dict is written back on exit.
                self._changed = True
896
+
897
+
898
class SafeUnpickler(pickle.Unpickler):
    """Unpickler that only resolves a whitelist of plain builtin types.

    Any global outside WHITELIST is rejected, so the bytes being loaded cannot make
    unpickling import or call arbitrary objects.
    """

    WHITELIST = {'builtins': {'str', 'bool', 'int', 'float', 'list', 'set', 'dict'}}

    def find_class(self, module, name):
        if name in self.WHITELIST.get(module, ()):
            return super().find_class(module, name)
        # UnpicklingError is the exception the pickle documentation prescribes for
        # forbidden globals; PicklingError is meant for failures while pickling.
        raise pickle.UnpicklingError(f'Unpickling {module}.{name} is illegal!')
905
+
906
+
907
# Run SharedDict.destroy_shared_memory at interpreter exit; it only unlinks the segment in the main process.
+ atexit.register(SharedDict.destroy_shared_memory)
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import functools
16
+ from msprobe.core.common.const import Const
17
+ from msprobe.core.common.file_utils import check_file_or_directory_path
18
+ from msprobe.core.common.file_utils import save_npy
19
+
20
+
21
class FrameworkDescriptor:
    """Descriptor that returns the owner's framework module, importing it on first access.

    On access it reads ``owner._framework``. If that is still None, it calls
    ``owner.import_framework()`` before returning the (now populated) attribute.
    """

    def __get__(self, instance, owner):
        cached = owner._framework
        if cached is None:
            owner.import_framework()
            cached = owner._framework
        return cached
26
+
27
+
28
class FmkAdp:
    """Adapter that hides the API differences between PyTorch and MindSpore.

    The active framework is chosen with ``set_fmk`` (PyTorch by default). The
    framework module is imported lazily, on first access to ``framework``.
    """

    fmk = Const.PT_FRAMEWORK
    supported_fmk = [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]
    supported_dtype_list = ["bfloat16", "float16", "float32", "float64"]
    _framework = None
    framework = FrameworkDescriptor()

    @classmethod
    def import_framework(cls):
        """Import the module for ``cls.fmk`` and cache it in ``cls._framework``.

        Raises:
            Exception: if ``cls.fmk`` is not a supported framework.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            import torch
            cls._framework = torch
        elif cls.fmk == Const.MS_FRAMEWORK:
            import mindspore
            cls._framework = mindspore
        else:
            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")

    @classmethod
    def set_fmk(cls, fmk=Const.PT_FRAMEWORK):
        """Select the active framework.

        Raises:
            Exception: if *fmk* is not in ``supported_fmk``.
        """
        if fmk not in cls.supported_fmk:
            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")
        cls.fmk = fmk
        cls._framework = None  # Reset the cached module so the next access re-imports it.

    @classmethod
    def get_rank(cls):
        """Return the rank from torch.distributed or mindspore.communication."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return cls.framework.distributed.get_rank()
        return cls.framework.communication.get_rank()

    @classmethod
    def get_rank_id(cls):
        """Return the rank if communication is initialized, otherwise 0."""
        if cls.is_initialized():
            return cls.get_rank()
        return 0

    @classmethod
    def is_initialized(cls):
        """Report whether the framework's distributed communication is initialized."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return cls.framework.distributed.is_initialized()
        return cls.framework.communication.GlobalComm.INITED

    @classmethod
    def is_nn_module(cls, module):
        """Return True if *module* is a torch ``nn.Module`` / mindspore ``nn.Cell``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return isinstance(module, cls.framework.nn.Module)
        return isinstance(module, cls.framework.nn.Cell)

    @classmethod
    def is_tensor(cls, tensor):
        """Return True if *tensor* is an instance of the active framework's Tensor type."""
        # Both frameworks expose their tensor class as ``<module>.Tensor``, so one check covers both.
        return isinstance(tensor, cls.framework.Tensor)

    @classmethod
    def process_tensor(cls, tensor, func):
        """Apply reduction *func* to *tensor* and return the result as a Python float.

        For PyTorch, non-floating-point and float64 tensors are first cast with
        ``.float()``. For MindSpore, the result is converted through ``asnumpy()``.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not tensor.is_floating_point() or tensor.dtype == cls.framework.float64:
                tensor = tensor.float()
            return float(func(tensor))
        return float(func(tensor).asnumpy())

    @classmethod
    def tensor_max(cls, tensor):
        """Return ``tensor.max()`` as a float (see ``process_tensor``)."""
        return cls.process_tensor(tensor, lambda x: x.max())

    @classmethod
    def tensor_min(cls, tensor):
        """Return ``tensor.min()`` as a float (see ``process_tensor``)."""
        return cls.process_tensor(tensor, lambda x: x.min())

    @classmethod
    def tensor_mean(cls, tensor):
        """Return ``tensor.mean()`` as a float (see ``process_tensor``)."""
        return cls.process_tensor(tensor, lambda x: x.mean())

    @classmethod
    def tensor_norm(cls, tensor):
        """Return ``tensor.norm()`` as a float (see ``process_tensor``)."""
        return cls.process_tensor(tensor, lambda x: x.norm())

    @classmethod
    def save_tensor(cls, tensor, filepath):
        """Convert *tensor* to a numpy array and write it to *filepath* via ``save_npy``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            tensor_npy = tensor.cpu().detach().float().numpy()
        else:
            tensor_npy = tensor.asnumpy()
        save_npy(tensor_npy, filepath)

    @classmethod
    def dtype(cls, dtype_str):
        """Return the framework dtype object named *dtype_str*.

        Raises:
            Exception: if *dtype_str* is not in ``supported_dtype_list``.
        """
        if dtype_str not in cls.supported_dtype_list:
            raise Exception(f"{dtype_str} is not supported by adapter, not in {cls.supported_dtype_list}")
        return getattr(cls.framework, dtype_str)

    @classmethod
    def named_parameters(cls, module):
        """Return the module's (name, parameter) iterator for the active framework.

        Raises:
            Exception: if *module* is not the framework's module/cell type.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not isinstance(module, cls.framework.nn.Module):
                raise Exception(f"{module} is not a torch.nn.Module")
            return module.named_parameters()
        if not isinstance(module, cls.framework.nn.Cell):
            raise Exception(f"{module} is not a mindspore.nn.Cell")
        return module.parameters_and_names()

    @classmethod
    def register_forward_pre_hook(cls, module, hook, with_kwargs=False):
        """Run *hook* before *module*'s forward computation.

        For PyTorch, this delegates to ``register_forward_pre_hook``. For MindSpore,
        the cell's ``construct`` is wrapped so that ``hook(module, args)`` (or
        ``hook(module, args, kwargs)`` when *with_kwargs*) runs first. The wrapper
        discards the hook's return value.

        Raises:
            Exception: if *module* is not the framework's module/cell type.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not isinstance(module, cls.framework.nn.Module):
                raise Exception(f"{module} is not a torch.nn.Module")
            module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
        else:
            if not isinstance(module, cls.framework.nn.Cell):
                raise Exception(f"{module} is not a mindspore.nn.Cell")
            original_construct = module.construct

            @functools.wraps(original_construct)
            def new_construct(*args, **kwargs):
                if with_kwargs:
                    hook(module, args, kwargs)
                else:
                    hook(module, args)
                return original_construct(*args, **kwargs)

            module.construct = new_construct

    @classmethod
    def load_checkpoint(cls, path, to_cpu=True, weights_only=True):
        """Load a checkpoint file after validating *path*.

        For PyTorch, this uses ``torch.load`` (mapped to CPU when *to_cpu*). For
        MindSpore, it uses ``mindspore.load_checkpoint``.

        Raises:
            RuntimeError: if ``torch.load`` fails (the original error is chained).
        """
        check_file_or_directory_path(path)
        if cls.fmk == Const.PT_FRAMEWORK:
            try:
                if to_cpu:
                    return cls.framework.load(path, map_location=cls.framework.device("cpu"), weights_only=weights_only)
                return cls.framework.load(path, weights_only=weights_only)
            except Exception as e:
                raise RuntimeError(f"load pt file {path} failed: {e}") from e
        # `mindspore` is only bound locally inside import_framework, so use the
        # lazily imported module instead of a bare (undefined) `mindspore` name.
        return cls.framework.load_checkpoint(path)

    @classmethod
    def asnumpy(cls, tensor):
        """Return *tensor* cast to float as a numpy array."""
        if cls.fmk == Const.PT_FRAMEWORK:
            # Mirror save_tensor: numpy() rejects tensors that require grad or are not on CPU.
            return tensor.cpu().detach().float().numpy()
        return tensor.float().asnumpy()