mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -135,7 +135,8 @@ class PathManager:
135
135
  msg = f"The path does not exist: {path}"
136
136
  raise ProfilerPathErrorException(msg)
137
137
  if os.name != 'nt' and os.stat(path).st_uid != os.getuid():
138
- msg = f"Path {path} owner[{os.stat(path).st_uid}] does not match the current user[{os.getuid()}]."
138
+ msg = (f"Path {path} owner[{os.stat(path).st_uid}] does not match the current user[{os.getuid()}]."
139
+ f"Please execute chown -R $(id -un) {path}")
139
140
  raise ProfilerPathErrorException(msg)
140
141
 
141
142
  @classmethod
@@ -153,7 +154,7 @@ class PathManager:
153
154
  msg = f"Invalid path is a soft link: {path}"
154
155
  raise ProfilerPathErrorException(msg)
155
156
  if not os.access(path, os.W_OK):
156
- msg = f"The path writeable permission check failed: {path}"
157
+ msg = f"The path writeable permission check failed: {path}. Please execute chmod -R 755 {path}"
157
158
  raise ProfilerPathErrorException(msg)
158
159
 
159
160
  @classmethod
@@ -171,7 +172,7 @@ class PathManager:
171
172
  msg = f"Invalid path is a soft link: {path}"
172
173
  raise ProfilerPathErrorException(msg)
173
174
  if not os.access(path, os.R_OK):
174
- msg = f"The path readable permission check failed: {path}"
175
+ msg = f"The path readable permission check failed: {path}. Please execute chmod -R 755 {path}"
175
176
  raise ProfilerPathErrorException(msg)
176
177
 
177
178
  @classmethod
@@ -246,26 +247,6 @@ class PathManager:
246
247
  except Exception as err:
247
248
  raise ProfilerPathErrorException(f"Failed to make directory: {path}, err: {err}") from err
248
249
 
249
- @classmethod
250
- def create_file_safety(cls, path: str):
251
- """
252
- Function Description:
253
- create file safety
254
- Parameter:
255
- path: the file to remove
256
- Exception Description:
257
- when invalid data throw exception
258
- """
259
- if os.path.islink(path):
260
- raise RuntimeError(f"Failed to create file: {path}, is a soft link")
261
- if os.path.exists(path):
262
- logger.warning("File already exists: %s", path)
263
- return
264
- try:
265
- os.close(os.open(path, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY))
266
- except Exception as err:
267
- raise RuntimeError(f"Failed to create file: {path}, err: {err}") from err
268
-
269
250
  @classmethod
270
251
  def _input_path_common_check(cls, path: str):
271
252
  """
@@ -399,6 +380,57 @@ class PathManager:
399
380
  return False
400
381
  if os.name == 'nt':
401
382
  return False
402
- if os.stat(path).st_uid == 0 or os.stat(path).st_uid == os.getuid():
383
+ if os.stat(lib_path).st_uid == 0 or os.stat(lib_path).st_uid == os.getuid():
403
384
  return True
404
385
  return False
386
+
387
+ @classmethod
388
+ def check_path_is_other_writable(cls, path):
389
+ """Check whether the file or directory in the specified path has writable permissions for others."""
390
+ file_stat = os.stat(path)
391
+ if file_stat.st_mode & (stat.S_IWGRP | stat.S_IWOTH):
392
+ msg = (f"File path {path} has group or others writable permissions, which is not allowed."
393
+ f"Please execute chmod -R 755 {path}")
394
+ raise ProfilerPathErrorException(msg)
395
+
396
+ @classmethod
397
+ def check_path_is_owner_or_root(cls, path):
398
+ """Check path is owner or root."""
399
+ if not os.path.exists(path):
400
+ msg = f"The path does not exist: {path}"
401
+ raise ProfilerPathErrorException(msg)
402
+ file_stat = os.stat(path)
403
+ current_uid = os.getuid()
404
+ file_uid = file_stat.st_uid
405
+ if file_uid not in (0, current_uid):
406
+ return False
407
+ return True
408
+
409
+ @classmethod
410
+ def check_path_is_executable(cls, path):
411
+ """Check path is executable"""
412
+ return os.access(path, os.X_OK)
413
+
414
+ @classmethod
415
+ def check_path_is_readable(cls, path):
416
+ """Check path is readable"""
417
+ if os.path.islink(path):
418
+ msg = f"Invalid path is a soft link: {path}"
419
+ raise ProfilerPathErrorException(msg)
420
+ if not os.access(path, os.R_OK):
421
+ msg = f"The path readable permission check failed: {path}."
422
+ raise ProfilerPathErrorException(msg)
423
+
424
+ @classmethod
425
+ def walk_with_depth(cls, path, *args, max_depth=10, **kwargs):
426
+ """walk path depth"""
427
+ if not isinstance(path, str):
428
+ return
429
+ base_depth = path.count(os.sep)
430
+ if path.endswith(os.sep):
431
+ base_depth -= 1
432
+ for root, dirs, files in os.walk(path, *args, **kwargs):
433
+ if root.count(os.sep) - base_depth > max_depth:
434
+ dirs.clear()
435
+ continue
436
+ yield root, dirs, files
@@ -24,7 +24,6 @@ from typing import (
24
24
  )
25
25
 
26
26
  from mindspore.communication.management import GlobalComm
27
- from mindspore.communication.management import get_local_rank
28
27
  from mindspore.communication.management import get_rank
29
28
  from mindspore.profiler.common.constant import (
30
29
  DeviceTarget,
@@ -44,6 +43,7 @@ from mindspore import context
44
43
  from mindspore import log as logger
45
44
  from mindspore.profiler.common.profiler_info import ProfilerInfo
46
45
  from mindspore.profiler.experimental_config import _ExperimentalConfig
46
+ from mindspore.profiler.common.util import get_device_id
47
47
 
48
48
 
49
49
  @Singleton
@@ -488,17 +488,7 @@ class ProfilerContext:
488
488
  """
489
489
  Initialize the device ID.
490
490
  """
491
- self._device_id = str(context.get_context("device_id"))
492
-
493
- if not self._device_id or not self._device_id.isdigit():
494
- if GlobalComm.INITED and self._device_target == DeviceTarget.NPU.value:
495
- self._device_id = str(get_local_rank())
496
- else:
497
- self._device_id = os.getenv("DEVICE_ID")
498
-
499
- if not self._device_id or not self._device_id.isdigit():
500
- self._device_id = "0"
501
- logger.warning("Fail to get DEVICE_ID, use 0 instead.")
491
+ self._device_id = get_device_id()
502
492
 
503
493
  def _init_rank_id(self) -> None:
504
494
  """
@@ -97,13 +97,14 @@ class ProfilerInfo:
97
97
  Load time parameters from msprof profile and host start log.
98
98
  This method should be called before TimeConverter.init_parameters.
99
99
  """
100
+ msprof_info = MsprofCmdTool(msprof_profile_path).get_msprof_info()
100
101
  if not msprof_profile_path or not msprof_profile_host_path:
101
102
  raise ValueError(
102
103
  "msprof_profile_path and msprof_profile_host_path must be provided"
103
104
  )
104
105
  self._read_host_start_log(msprof_profile_host_path)
105
106
  self._read_start_info(msprof_profile_host_path)
106
- self._get_freq_from_msprof(msprof_profile_path)
107
+ self._get_freq_from_msprof(msprof_info)
107
108
 
108
109
  @property
109
110
  def time_parameters(self) -> Dict[str, Any]:
@@ -237,7 +238,7 @@ class ProfilerInfo:
237
238
  self._collection_time_begin * self.US_TO_NS - self._clock_monotonic_raw_info
238
239
  )
239
240
 
240
- def _get_freq_from_msprof(self, msprof_profile_path: str) -> None:
241
+ def _get_freq_from_msprof(self, msprof_info: str) -> None:
241
242
  """
242
243
  Get frequency from get_msprof_info.py script
243
244
 
@@ -250,7 +251,6 @@ class ProfilerInfo:
250
251
  }
251
252
  }
252
253
  """
253
- msprof_info = MsprofCmdTool(msprof_profile_path).get_msprof_info()
254
254
 
255
255
  if not isinstance(msprof_info, dict):
256
256
  raise RuntimeError("msprof_info must be a dictionary")
@@ -143,6 +143,19 @@ class ProfilerPathManager:
143
143
  new_file_path = os.path.join(self._prof_ctx.ascend_profiler_output_path, new_file_name)
144
144
  shutil.move(db_file, new_file_path)
145
145
 
146
+ def remove_db_file(self):
147
+ """
148
+ Remove the db file in the output path.
149
+ """
150
+ if not self._prof_ctx.msprof_profile_output_path:
151
+ return
152
+ db_files = glob.glob(os.path.join(
153
+ os.path.dirname(self._prof_ctx.msprof_profile_output_path),
154
+ 'msprof*.db'
155
+ ))
156
+ for db_file in db_files:
157
+ if os.path.isfile(db_file):
158
+ os.remove(db_file)
146
159
 
147
160
  def create_output_path(self):
148
161
  """
@@ -25,10 +25,15 @@ import re
25
25
  import shutil
26
26
  import stat
27
27
 
28
+ from mindspore.communication.management import GlobalComm
29
+ from mindspore.communication.management import get_local_rank
30
+ from mindspore import context
28
31
  from mindspore import log as logger
29
32
  from mindspore.profiler.common.path_manager import PathManager
30
33
  from mindspore.profiler.common.exceptions.exceptions import ProfilerPathErrorException
31
34
 
35
+ from mindspore.profiler.common.constant import DeviceTarget
36
+
32
37
 
33
38
  def no_exception_func(
34
39
  default_ret: Any = None,
@@ -75,12 +80,16 @@ def get_cann_version():
75
80
  ascend_home_path = os.environ.get("ASCEND_HOME_PATH", "")
76
81
  cann_version = "not known"
77
82
  try:
78
- PathManager.check_directory_path_readable(os.path.realpath(ascend_home_path))
79
- for dirpath, _, filenames in os.walk(os.path.realpath(ascend_home_path)):
83
+ if not PathManager.check_path_is_owner_or_root(ascend_home_path):
84
+ raise PermissionError(f"PermissionError, CANN package user id: {os.stat(ascend_home_path).st_uid}, "
85
+ f"current user id: {os.getuid()}. "
86
+ f"Ensure CANN package user id and current user id consistency")
87
+ PathManager.check_path_is_readable(os.path.realpath(ascend_home_path))
88
+ for dirpath, _, filenames in PathManager.walk_with_depth(os.path.realpath(ascend_home_path)):
80
89
  install_files = [file for file in filenames if re.match(r"ascend_.{1,20}_install\.info", file)]
81
90
  if install_files:
82
91
  filepath = os.path.realpath(os.path.join(dirpath, install_files[0]))
83
- PathManager.check_directory_path_readable(filepath)
92
+ PathManager.check_path_is_readable(filepath)
84
93
  with open(filepath, "r") as f:
85
94
  for line in f:
86
95
  if line.find("version") != -1:
@@ -441,6 +450,24 @@ def get_newest_file(file_list):
441
450
  return newest_file_list
442
451
 
443
452
 
453
+ def get_device_id():
454
+ """
455
+ Get device ID.
456
+ """
457
+ device_id = str(context.get_context("device_id"))
458
+
459
+ if not device_id or not device_id.isdigit():
460
+ if GlobalComm.INITED and context.get_context("device_target") == DeviceTarget.NPU.value:
461
+ device_id = str(get_local_rank())
462
+ else:
463
+ device_id = os.getenv("DEVICE_ID")
464
+
465
+ if not device_id or not device_id.isdigit():
466
+ logger.warning("Fail to get DEVICE_ID, use 0 instead.")
467
+ device_id = "0"
468
+ return device_id
469
+
470
+
444
471
  class ProfilerPathManager:
445
472
  """A path manager to manage profiler path"""
446
473
 
@@ -84,7 +84,8 @@ class _ExperimentalConfig:
84
84
  HCCS data, PCIe data, and Stars Chip Trans. Default: ``False``.
85
85
  host_sys (list, optional): Collect the data of system call classes on the host side.
86
86
  Default: ``[]``, indicating that system class data on the host side is not collected.
87
- You need to set `start_profile` of :class:`mindspore.profiler.profile` to ``False``.When collecting DISK or
87
+ You need to set `start_profile` of :class:`mindspore.profiler.profile` to ``False``.Currently, only
88
+ the **root user** supports collecting DISK or OSRT data, when collecting DISK or
88
89
  OSRT data, it is necessary to install the iotop, perf, and ltrace third-party tools in advance.
89
90
  For detailed steps, please refer to `Installing Third-party Tools
90
91
  <https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/T&ITools/Profiling/atlasprofiling_16_0136.
@@ -36,6 +36,8 @@ from mindspore.profiler.platform.base_profiler import BaseProfiler
36
36
  from mindspore.profiler.common.profiler_path_manager import ProfilerPathManager
37
37
  from mindspore.profiler.common.profiler_info import ProfilerInfo
38
38
  from mindspore.profiler.common.process_pool import MultiProcessPool
39
+ from mindspore.profiler.common.constant import MsprofModeName
40
+ from mindspore.profiler.common.util import no_exception_func
39
41
  from mindspore.profiler.analysis.task_manager import TaskManager
40
42
  from mindspore.profiler.analysis.time_converter import TimeConverter
41
43
  from mindspore.profiler.analysis.parser.ascend_cann_parser import AscendMsprofParser
@@ -58,6 +60,7 @@ from mindspore.profiler.analysis.viewer.ms_operator_details_viewer import MsOper
58
60
  from mindspore.profiler.common.util import print_msg_with_pid
59
61
  from mindspore.profiler.common.log import ProfilerLogger
60
62
  from mindspore.profiler.mstx import Mstx
63
+ from mindspore.profiler.common.util import get_device_id
61
64
 
62
65
 
63
66
  @PROFILERS.register_module(DeviceTarget.NPU.value)
@@ -68,6 +71,7 @@ class NpuProfiler(BaseProfiler):
68
71
 
69
72
  def __init__(self) -> None:
70
73
  super().__init__()
74
+ self._is_env_not_valid = self._is_environment_not_valid()
71
75
  self._prof_ctx = ProfilerContext()
72
76
  self._prof_info = ProfilerInfo()
73
77
  self._prof_path_mgr = ProfilerPathManager()
@@ -78,7 +82,7 @@ class NpuProfiler(BaseProfiler):
78
82
  # initialize profiler backend
79
83
  self._profiler.init(
80
84
  self._prof_ctx.ascend_ms_dir,
81
- int(self._prof_ctx.device_id),
85
+ int(get_device_id()),
82
86
  json.dumps(self._prof_ctx.npu_profiler_params),
83
87
  )
84
88
  self._logger.info("NpuProfiler init profiler backend params %s",
@@ -97,6 +101,8 @@ class NpuProfiler(BaseProfiler):
97
101
 
98
102
  def start(self) -> None:
99
103
  """Start profiling."""
104
+ if self._is_env_not_valid:
105
+ return
100
106
  self._logger.info("NpuProfiler start.")
101
107
 
102
108
  Mstx.enable = self._prof_ctx.npu_profiler_params.get("mstx", False)
@@ -115,6 +121,8 @@ class NpuProfiler(BaseProfiler):
115
121
 
116
122
  def stop(self) -> None:
117
123
  """Stop profiling."""
124
+ if self._is_env_not_valid:
125
+ return
118
126
  self._logger.info("NpuProfiler stop.")
119
127
 
120
128
  Mstx.enable = False
@@ -149,16 +157,29 @@ class NpuProfiler(BaseProfiler):
149
157
 
150
158
  def analyse(self, **kwargs) -> None:
151
159
  """Analyse the profiling data."""
160
+ if self._is_env_not_valid:
161
+ return
152
162
  self._logger.info("NpuProfiler analyse.")
153
163
 
154
164
  NPUProfilerAnalysis.online_analyse(async_mode=kwargs.get('async_mode'))
155
165
 
156
166
  def finalize(self) -> None:
157
167
  """Finalize profiling data."""
168
+ if self._is_env_not_valid:
169
+ return
158
170
  self._logger.info("NpuProfiler finalize.")
159
171
  if self._profiler:
160
172
  self._profiler.finalize()
161
173
 
174
+ @staticmethod
175
+ def _is_environment_not_valid() -> bool:
176
+ # check msprof dynamic environment variable
177
+ if os.getenv(MsprofModeName.MSPROF_DYNAMIC_ENV) is not None:
178
+ logger.error(f"The environment variable '{MsprofModeName.MSPROF_DYNAMIC_ENV}' has been set."
179
+ f"Please execute 'unset {MsprofModeName.MSPROF_DYNAMIC_ENV}'.")
180
+ return True
181
+ return False
182
+
162
183
 
163
184
  class NPUProfilerAnalysis:
164
185
  """
@@ -166,6 +187,7 @@ class NPUProfilerAnalysis:
166
187
  """
167
188
 
168
189
  @classmethod
190
+ @no_exception_func()
169
191
  def online_analyse(cls, async_mode: bool = False):
170
192
  """
171
193
  Online analysis for NPU
@@ -179,6 +201,7 @@ class NPUProfilerAnalysis:
179
201
  cls._run_tasks(**ProfilerContext().to_dict())
180
202
 
181
203
  @classmethod
204
+ @no_exception_func()
182
205
  def offline_analyse(
183
206
  cls,
184
207
  path: str,
@@ -262,11 +285,15 @@ class NPUProfilerAnalysis:
262
285
  task_mgr = cls._construct_task_mgr(**kwargs)
263
286
  task_mgr.run()
264
287
  ProfilerLogger.get_instance().info(json.dumps(task_mgr.cost_time, indent=4))
265
- activities = kwargs.get("activities")
266
- if activities and ProfilerActivity.NPU.value in activities:
267
- ProfilerPathManager().move_db_file()
268
- if kwargs.get("data_simplification") and ProfilerActivity.NPU.value in kwargs.get("activities"):
269
- ProfilerPathManager().simplify_data()
288
+ activities = kwargs.get("activities", [])
289
+ export_type = kwargs.get("export_type", [])
290
+ if ProfilerActivity.NPU.value in activities:
291
+ if ExportType.Db.value in export_type:
292
+ ProfilerPathManager().move_db_file()
293
+ else:
294
+ ProfilerPathManager().remove_db_file()
295
+ if kwargs.get("data_simplification"):
296
+ ProfilerPathManager().simplify_data()
270
297
 
271
298
  @classmethod
272
299
  def _construct_task_mgr(cls, **kwargs) -> TaskManager:
@@ -44,6 +44,23 @@ class EnvChecker(metaclass=ABCMeta):
44
44
  def check_version(self):
45
45
  pass
46
46
 
47
+ @staticmethod
48
+ def _concat_variable(env_name, env_value):
49
+ """concat value to the beginning of env specified by env_name"""
50
+ if not os.getenv(env_name, ""):
51
+ os.environ[env_name] = env_value
52
+ else:
53
+ paths = os.environ[env_name].split(':')
54
+ if paths and paths[0] == env_value:
55
+ return
56
+ if env_value not in paths:
57
+ os.environ[env_name] = env_value + ':' + os.environ[env_name]
58
+ else:
59
+ # move env_value to beginning
60
+ new_paths = [p for p in paths if p != env_value]
61
+ new_paths.insert(0, env_value)
62
+ os.environ[env_name] = ':'.join(new_paths)
63
+
47
64
 
48
65
  class CPUEnvChecker(EnvChecker):
49
66
  """CPU environment check."""
@@ -61,10 +78,7 @@ class CPUEnvChecker(EnvChecker):
61
78
  """set env for cpu"""
62
79
  plugin_dir = os.path.dirname(self.library_path)
63
80
  akg_dir = os.path.join(plugin_dir, "plugin/cpu")
64
- if os.getenv('LD_LIBRARY_PATH'):
65
- os.environ['LD_LIBRARY_PATH'] = akg_dir + ":" + os.environ['LD_LIBRARY_PATH']
66
- else:
67
- os.environ['LD_LIBRARY_PATH'] = akg_dir
81
+ EnvChecker._concat_variable('LD_LIBRARY_PATH', akg_dir)
68
82
 
69
83
 
70
84
  class GPUEnvChecker(EnvChecker):
@@ -142,10 +156,7 @@ class GPUEnvChecker(EnvChecker):
142
156
  v_str = str(v.major) + "." + str(v.minor)
143
157
  plugin_dir = os.path.dirname(self.library_path)
144
158
  akg_dir = os.path.join(plugin_dir, "gpu" + v_str)
145
- if os.getenv('LD_LIBRARY_PATH'):
146
- os.environ['LD_LIBRARY_PATH'] = akg_dir + ":" + os.environ['LD_LIBRARY_PATH']
147
- else:
148
- os.environ['LD_LIBRARY_PATH'] = akg_dir
159
+ EnvChecker._concat_variable('LD_LIBRARY_PATH', akg_dir)
149
160
  os.environ['CUDA_CACHE_MAXSIZE'] = "4000000000"
150
161
 
151
162
  def _get_bin_path(self, bin_name):
@@ -258,7 +269,7 @@ class AscendEnvChecker(EnvChecker):
258
269
 
259
270
  def __init__(self, library_path):
260
271
  self.library_path = library_path
261
- self.version = ["7.7", "7.8", "8.2"]
272
+ self.version = ["7.8", "8.2", "8.3"]
262
273
 
263
274
  # env
264
275
  self.path = os.getenv("PATH")
@@ -278,13 +289,6 @@ class AscendEnvChecker(EnvChecker):
278
289
  self.ascend_opp_kernel_path_check = "/opp_kernel"
279
290
  self.v = ""
280
291
 
281
- @staticmethod
282
- def _concat_variable(env_name, env_value):
283
- if os.getenv(env_name) is None:
284
- os.environ[env_name] = env_value
285
- else:
286
- os.environ[env_name] = env_value + ":" + os.environ[env_name]
287
-
288
292
  def check_custom_version(self):
289
293
  """custom op version check"""
290
294
 
@@ -359,11 +363,6 @@ class AscendEnvChecker(EnvChecker):
359
363
  """
360
364
  opp kernel install check
361
365
  """
362
- from mindspore._c_expression import MSContext
363
- soc_version = MSContext.get_instance().get_ascend_soc_version()
364
- if soc_version == "ascend310":
365
- return
366
-
367
366
  opp_kernel_path = self.ascend_opp_path.replace("opp", "opp_kernel")
368
367
  if not os.path.exists(opp_kernel_path):
369
368
  logger.critical("MindSpore relies on \"Ascend opp_kernel\" folder of the Ascend AI software package ("
@@ -389,7 +388,7 @@ class AscendEnvChecker(EnvChecker):
389
388
  os.environ['IGNORE_INFER_ERROR'] = "1"
390
389
  plugin_dir = os.path.dirname(self.library_path)
391
390
  akg_dir = os.path.join(plugin_dir, "ascend")
392
- AscendEnvChecker._concat_variable('LD_LIBRARY_PATH', akg_dir)
391
+ EnvChecker._concat_variable('LD_LIBRARY_PATH', akg_dir)
393
392
 
394
393
  self._check_env()
395
394
 
@@ -493,6 +492,85 @@ def _set_pb_env():
493
492
  os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
494
493
 
495
494
 
495
def _check_dir_path_safety(dir_path):
    """
    Check safety of env directory path.

    The path must exist, be a directory, and must not contain a path component that names a
    temporary/volatile location or a system-critical directory. The check is a heuristic on the
    spelling of the path: it is case-insensitive, treats ``/`` and ``\\`` alike, and does not
    resolve symlinks or ``..``.

    Args:
        dir_path (str): Directory path taken from an environment variable.

    Returns:
        bool, ``True`` if the path passes all checks, ``False`` otherwise (a warning is logged).
    """
    if not os.path.exists(dir_path):
        logger.warning(f"Path {dir_path} not exists.")
        return False

    if not os.path.isdir(dir_path):
        logger.warning(f"Path {dir_path} is not a directory.")
        return False

    # Compare whole path components instead of raw substrings: a substring test such as
    # "/bin" in path also rejects "/opt/binutils", "/tmp/" misses the directory "/tmp" itself,
    # and forward-slash patterns never match backslash-separated Windows paths.
    components = {part for part in dir_path.lower().replace("\\", "/").split("/") if part}

    # check if path is suspicious
    suspicious_names = {"tmp", "temp", "dev", "proc", "appdata"}
    if components & suspicious_names:
        logger.warning(f"Path {dir_path} is suspicious.")
        return False

    # check whether the path points to a system-critical directory
    critical_names = {"bin", "sbin", "windows", "system32"}
    if components & critical_names:
        logger.warning(f"Path {dir_path} points to a system-critical directory.")
        return False

    return True
528
+
529
+
530
def check_cuda_path_safety(cuda_path):
    """
    Check safety of cuda path.

    The path must pass ``_check_dir_path_safety`` and its ``bin`` sub-directory must contain at
    least one file that is typical for a CUDA toolkit.

    Args:
        cuda_path (str): Root directory of the CUDA installation.

    Returns:
        bool, ``True`` if the directory looks like a CUDA installation, ``False`` otherwise.
    """
    if not _check_dir_path_safety(cuda_path):
        return False

    bin_dir = os.path.join(cuda_path, "bin")
    # "nvcc.exe" is listed next to "nvcc" because the caller in this file runs on Windows only,
    # where the compiler driver carries the ".exe" extension.
    expected_files = ("nvcc", "nvcc.exe", "cudart.dll", "cudart.so")
    has_expected_content = any(
        os.path.exists(os.path.join(bin_dir, expected_file)) for expected_file in expected_files
    )

    if not has_expected_content:
        # The Windows CUDA runtime DLL is versioned (cudart64_<version>.dll), so accept any
        # "cudart*.dll" instead of the exact name only.
        try:
            bin_entries = os.listdir(bin_dir)
        except OSError:
            bin_entries = []
        has_expected_content = any(
            entry.lower().startswith("cudart") and entry.lower().endswith(".dll")
            for entry in bin_entries
        )

    if not has_expected_content:
        logger.warning(f"The directory {cuda_path} does not contain the typical file structure of CUDA")
        return False

    return True
547
+
548
+
549
def check_cudnn_path_safety(cudnn_path):
    """
    Check safety of cudnn path.

    The path must pass ``_check_dir_path_safety`` and contain at least one file that is typical
    for a cuDNN installation.

    Args:
        cudnn_path (str): Root directory of the cuDNN installation.

    Returns:
        bool, ``True`` if the directory looks like a cuDNN installation, ``False`` otherwise.
    """
    if not _check_dir_path_safety(cudnn_path):
        return False

    expected_files = (
        "include/cudnn.h",
        "lib64/libcudnn.so",   # Linux
        "lib/libcudnn.dylib",  # macOS
        "lib/x64/cudnn.lib",   # Windows
    )
    found = any(
        os.path.exists(os.path.join(cudnn_path, expected_file)) for expected_file in expected_files
    )

    if not found:
        # Windows runtime DLL: accept every major version (cudnn64_<N>.dll), not only cudnn64_7.dll.
        try:
            bin_entries = os.listdir(os.path.join(cudnn_path, "bin"))
        except OSError:
            bin_entries = []
        found = any(
            entry.lower().startswith("cudnn64_") and entry.lower().endswith(".dll")
            for entry in bin_entries
        )

    if not found:
        logger.warning(f"The directory {cudnn_path} does not contain the typical file structure of CUDNN")
        return False

    return True
572
+
573
+
496
574
  def _add_cuda_path():
497
575
  """add cuda path on windows."""
498
576
  if platform.system().lower() == 'windows':
@@ -500,15 +578,21 @@ def _add_cuda_path():
500
578
  if cuda_home is None:
501
579
  pass
502
580
  else:
581
+ if not check_cuda_path_safety(cuda_home):
582
+ logger.error(f"CUDA_PATH {cuda_home} is not safe, skip add cuda path.")
583
+ return
503
584
  cuda_bin_path = os.path.join(os.environ['CUDA_PATH'], 'bin')
504
585
  if sys.version_info >= (3, 8):
505
586
  os.add_dll_directory(cuda_bin_path)
506
587
  else:
507
588
  os.environ['PATH'] += os.pathsep + cuda_bin_path
508
- cudann_home = os.environ.get('CUDNN_HOME')
509
- if cudann_home is None:
589
+ cudnn_home = os.environ.get('CUDNN_HOME')
590
+ if cudnn_home is None:
510
591
  pass
511
592
  else:
593
+ if not check_cudnn_path_safety(cudnn_home):
594
+ logger.error(f"CUDNN_HOME {cuda_home} is not safe, skip add cudnn home.")
595
+ return
512
596
  cuda_home_bin_path = os.path.join(os.environ['CUDNN_HOME'], 'bin')
513
597
  if sys.version_info >= (3, 8):
514
598
  os.add_dll_directory(cuda_home_bin_path)
@@ -21,7 +21,8 @@ from mindspore.runtime.executor import launch_blocking, dispatch_threads_num, se
21
21
  set_kernel_launch_group, set_kernel_launch_capture
22
22
  from mindspore.runtime.memory import set_memory, memory_stats, memory_reserved, max_memory_reserved, empty_cache,\
23
23
  memory_replay, reset_peak_memory_stats, memory_summary, memory_allocated,\
24
- max_memory_allocated, reset_max_memory_reserved, reset_max_memory_allocated
24
+ max_memory_allocated, reset_max_memory_reserved, reset_max_memory_allocated,\
25
+ PluggableAllocator, MemPool, use_mem_pool
25
26
  from mindspore.runtime.stream import Stream, synchronize, set_cur_stream, current_stream, \
26
27
  default_stream, communication_stream, StreamCtx
27
28
  from mindspore.runtime.event import Event
@@ -33,7 +34,7 @@ __all__ = [
33
34
  "Stream", "communication_stream", "synchronize", "set_cur_stream", "current_stream", "default_stream", "StreamCtx",
34
35
  "set_memory", "memory_stats", "memory_reserved", "max_memory_reserved", "empty_cache", "memory_replay",
35
36
  "reset_peak_memory_stats", "memory_summary", "memory_allocated", "max_memory_allocated",
36
- "reset_max_memory_reserved", "reset_max_memory_allocated", "Event"
37
+ "reset_max_memory_reserved", "reset_max_memory_allocated", "Event", "PluggableAllocator", "MemPool", "use_mem_pool"
37
38
  ]
38
39
 
39
40
  __all__.sort()
@@ -196,7 +196,7 @@ def set_kernel_launch_group(thread_num=2, kernel_group_num=8):
196
196
 
197
197
 
198
198
  @args_type_check(enable_capture_graph=bool)
199
- def set_kernel_launch_capture(enable_capture_graph):
199
+ def set_kernel_launch_capture(enable_capture_graph, op_capture_skip=None):
200
200
  """
201
201
  In O0/O1 mode, the incremental inference scenario supports graph capture.
202
202
  By capturing the CPU-side operator dispatch behavior into a graph,
@@ -208,12 +208,20 @@ def set_kernel_launch_capture(enable_capture_graph):
208
208
  Args:
209
209
  enable_capture_graph (bool): Whether to enable graph capture.
210
210
  It can be turned on or off at any position in the script.
211
+ op_capture_skip (list): Custom non-captured operator names. Default: ``None``.
211
212
 
212
213
  Examples:
213
214
  >>> import mindspore as ms
214
- >>> ms.runtime.set_kernel_launch_capture(enable_capture_graph=True)
215
+ >>> op_capture_skip = ['matmul', 'addn']
216
+ >>> ms.runtime.set_kernel_launch_capture(True, op_capture_skip)
215
217
  """
216
218
  if RuntimeConf.get_instance().is_kernel_launch_group_configured():
217
219
  raise RuntimeError("The kernel launch group and kernel launch capture can not be set together")
218
220
 
219
- return RuntimeConf.get_instance().set_kernel_launch_capture(enable_capture_graph)
221
+ if op_capture_skip is None:
222
+ op_capture_skip = []
223
+
224
+ if not isinstance(op_capture_skip, list):
225
+ raise TypeError("op_capture_skip must be a list")
226
+
227
+ return RuntimeConf.get_instance().set_kernel_launch_capture(enable_capture_graph, op_capture_skip)