mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (458)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +65 -84
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +58 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +178 -53
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +377 -203
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +5 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +117 -131
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +67 -55
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +70 -24
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +27 -7
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +6 -46
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +429 -23
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +140 -104
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +491 -623
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +117 -110
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +4 -6
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/asgd.py +2 -0
  188. mindspore/nn/optim/lamb.py +1 -3
  189. mindspore/nn/optim/optimizer.py +1 -1
  190. mindspore/nn/optim/tft_wrapper.py +2 -3
  191. mindspore/nn/optim/thor.py +2 -2
  192. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  193. mindspore/nn/probability/distribution/exponential.py +2 -1
  194. mindspore/nn/probability/distribution/poisson.py +2 -1
  195. mindspore/nn/sparse/sparse.py +3 -3
  196. mindspore/nn/wrap/cell_wrapper.py +73 -42
  197. mindspore/nn/wrap/grad_reducer.py +37 -52
  198. mindspore/nn/wrap/loss_scale.py +72 -74
  199. mindspore/numpy/array_creations.py +7 -7
  200. mindspore/numpy/fft.py +1 -1
  201. mindspore/numpy/math_ops.py +5 -5
  202. mindspore/numpy/utils_const.py +1 -1
  203. mindspore/opencv_core452.dll +0 -0
  204. mindspore/opencv_imgcodecs452.dll +0 -0
  205. mindspore/opencv_imgproc452.dll +0 -0
  206. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  207. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  208. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  209. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  210. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  211. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  212. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  213. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
  214. mindspore/ops/auto_generate/gen_extend_func.py +27 -145
  215. mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
  216. mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
  217. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  218. mindspore/ops/composite/__init__.py +10 -0
  219. mindspore/ops/composite/base.py +9 -5
  220. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  221. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  222. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  223. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  224. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  225. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  226. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  227. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  228. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  229. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  230. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  231. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  232. mindspore/ops/function/__init__.py +4 -1
  233. mindspore/ops/function/_add_attr_func.py +11 -6
  234. mindspore/ops/function/array_func.py +19 -102
  235. mindspore/ops/function/debug_func.py +8 -5
  236. mindspore/ops/function/grad/grad_func.py +5 -13
  237. mindspore/ops/function/math_func.py +77 -572
  238. mindspore/ops/function/nn_func.py +46 -94
  239. mindspore/ops/function/other_func.py +4 -1
  240. mindspore/ops/function/random_func.py +44 -5
  241. mindspore/ops/function/vmap_func.py +2 -1
  242. mindspore/ops/functional.py +4 -4
  243. mindspore/ops/functional_overload.py +594 -18
  244. mindspore/ops/op_info_register.py +21 -0
  245. mindspore/ops/operations/__init__.py +16 -11
  246. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  247. mindspore/ops/operations/_inner_ops.py +14 -18
  248. mindspore/ops/operations/_sequence_ops.py +1 -1
  249. mindspore/ops/operations/array_ops.py +5 -51
  250. mindspore/ops/operations/comm_ops.py +186 -41
  251. mindspore/ops/operations/custom_ops.py +303 -177
  252. mindspore/ops/operations/debug_ops.py +59 -4
  253. mindspore/ops/operations/image_ops.py +13 -13
  254. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  255. mindspore/ops/operations/math_ops.py +8 -9
  256. mindspore/ops/operations/nn_ops.py +8 -40
  257. mindspore/ops/primitive.py +9 -20
  258. mindspore/ops/tensor_method.py +63 -15
  259. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  260. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  261. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  262. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  263. mindspore/ops_generate/common/base_generator.py +14 -0
  264. mindspore/ops_generate/common/gen_constants.py +8 -3
  265. mindspore/ops_generate/common/gen_utils.py +0 -19
  266. mindspore/ops_generate/common/op_proto.py +11 -4
  267. mindspore/ops_generate/common/template.py +88 -11
  268. mindspore/ops_generate/gen_ops.py +1 -1
  269. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  270. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  271. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  272. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  273. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  274. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  275. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  276. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  277. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  278. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  279. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  280. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  281. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  282. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  283. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  284. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  285. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  286. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  287. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  288. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  289. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  290. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  291. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  292. mindspore/parallel/_auto_parallel_context.py +16 -23
  293. mindspore/parallel/_cell_wrapper.py +113 -45
  294. mindspore/parallel/_parallel_serialization.py +4 -3
  295. mindspore/parallel/_ps_context.py +4 -6
  296. mindspore/parallel/_tensor.py +167 -12
  297. mindspore/parallel/_transformer/moe.py +1 -1
  298. mindspore/parallel/_transformer/transformer.py +17 -12
  299. mindspore/parallel/_utils.py +5 -11
  300. mindspore/parallel/auto_parallel.py +35 -14
  301. mindspore/parallel/checkpoint_convert.py +3 -3
  302. mindspore/parallel/checkpoint_transform.py +13 -7
  303. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  304. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  305. mindspore/parallel/cluster/run.py +48 -7
  306. mindspore/parallel/function/__init__.py +8 -1
  307. mindspore/parallel/function/reshard_func.py +12 -12
  308. mindspore/parallel/nn/__init__.py +15 -2
  309. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  310. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  311. mindspore/parallel/shard.py +10 -25
  312. mindspore/parallel/transform_safetensors.py +469 -174
  313. mindspore/pgodb140.dll +0 -0
  314. mindspore/pgort140.dll +0 -0
  315. mindspore/profiler/__init__.py +2 -1
  316. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  317. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  318. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  319. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  321. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  322. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  323. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  324. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  325. mindspore/profiler/analysis/task_manager.py +1 -1
  326. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  327. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  328. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  329. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  330. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  331. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  332. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  333. mindspore/profiler/common/constant.py +16 -0
  334. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  335. mindspore/profiler/common/path_manager.py +9 -0
  336. mindspore/profiler/common/profiler_context.py +50 -29
  337. mindspore/profiler/common/profiler_info.py +0 -16
  338. mindspore/profiler/common/profiler_meta_data.py +1 -0
  339. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  340. mindspore/profiler/common/profiler_output_path.py +23 -8
  341. mindspore/profiler/common/profiler_parameters.py +128 -35
  342. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  343. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  344. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  345. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  346. mindspore/profiler/dynamic_profiler.py +374 -338
  347. mindspore/profiler/envprofiler.py +42 -12
  348. mindspore/profiler/experimental_config.py +112 -7
  349. mindspore/profiler/mstx.py +33 -12
  350. mindspore/profiler/platform/__init__.py +2 -3
  351. mindspore/profiler/platform/cpu_profiler.py +10 -4
  352. mindspore/profiler/platform/npu_profiler.py +30 -20
  353. mindspore/profiler/profiler.py +218 -154
  354. mindspore/profiler/profiler_action_controller.py +65 -77
  355. mindspore/profiler/profiler_interface.py +2 -2
  356. mindspore/profiler/schedule.py +10 -4
  357. mindspore/rewrite/common/config.py +1 -0
  358. mindspore/rewrite/common/namer.py +1 -0
  359. mindspore/rewrite/common/namespace.py +1 -0
  360. mindspore/rewrite/node/node.py +31 -11
  361. mindspore/rewrite/parsers/assign_parser.py +1 -1
  362. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  363. mindspore/run_check/_check_version.py +7 -10
  364. mindspore/runtime/__init__.py +8 -6
  365. mindspore/runtime/event.py +10 -4
  366. mindspore/runtime/executor.py +87 -45
  367. mindspore/runtime/memory.py +31 -32
  368. mindspore/runtime/thread_bind_core.py +299 -165
  369. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  370. mindspore/swresample-4.dll +0 -0
  371. mindspore/swscale-6.dll +0 -0
  372. mindspore/tbbmalloc.dll +0 -0
  373. mindspore/tinyxml2.dll +0 -0
  374. mindspore/train/_utils.py +17 -7
  375. mindspore/train/amp.py +43 -23
  376. mindspore/train/callback/__init__.py +5 -5
  377. mindspore/train/callback/_callback.py +2 -1
  378. mindspore/train/callback/_checkpoint.py +4 -14
  379. mindspore/train/callback/_flops_collector.py +11 -7
  380. mindspore/train/callback/_landscape.py +0 -1
  381. mindspore/train/callback/_train_fault_tolerance.py +98 -21
  382. mindspore/train/data_sink.py +15 -6
  383. mindspore/train/dataset_helper.py +14 -5
  384. mindspore/train/model.py +133 -69
  385. mindspore/train/serialization.py +168 -126
  386. mindspore/train/summary/summary_record.py +13 -2
  387. mindspore/train/train_thor/model_thor.py +2 -2
  388. mindspore/turbojpeg.dll +0 -0
  389. mindspore/utils/__init__.py +3 -2
  390. mindspore/utils/dryrun.py +0 -6
  391. mindspore/utils/runtime_execution_order_check.py +163 -77
  392. mindspore/utils/sdc_detect.py +68 -0
  393. mindspore/utils/utils.py +14 -17
  394. mindspore/vcmeta.dll +0 -0
  395. mindspore/vcruntime140.dll +0 -0
  396. mindspore/vcruntime140_1.dll +0 -0
  397. mindspore/version.py +1 -1
  398. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  399. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
  400. mindspore/_deprecated/jit.py +0 -198
  401. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  402. mindspore/communication/_hccl_management.py +0 -297
  403. mindspore/experimental/es/embedding_service.py +0 -891
  404. mindspore/experimental/es/embedding_service_layer.py +0 -581
  405. mindspore/profiler/common/validator/__init__.py +0 -14
  406. mindspore/profiler/common/validator/validate_path.py +0 -84
  407. mindspore/profiler/parser/__init__.py +0 -14
  408. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  409. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  410. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  411. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  412. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  413. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  414. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  415. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  416. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  417. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  418. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  419. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  420. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  421. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  422. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  423. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  424. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  425. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  426. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  427. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  428. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  429. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  430. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  431. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  432. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  433. mindspore/profiler/parser/container.py +0 -229
  434. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  435. mindspore/profiler/parser/flops_parser.py +0 -531
  436. mindspore/profiler/parser/framework_enum.py +0 -111
  437. mindspore/profiler/parser/framework_parser.py +0 -464
  438. mindspore/profiler/parser/framework_struct.py +0 -61
  439. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  440. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  441. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  442. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  443. mindspore/profiler/parser/hccl_parser.py +0 -573
  444. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  445. mindspore/profiler/parser/integrator.py +0 -526
  446. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  447. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  448. mindspore/profiler/parser/minddata_parser.py +0 -186
  449. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  450. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  451. mindspore/profiler/parser/optime_parser.py +0 -250
  452. mindspore/profiler/parser/profiler_info.py +0 -213
  453. mindspore/profiler/parser/step_trace_parser.py +0 -666
  454. mindspore/utils/hooks.py +0 -81
  455. mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  456. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  457. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  458. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2021-2024 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ import re
21
21
  import ast
22
22
  import hashlib
23
23
  import stat
24
+ import copy
24
25
  import inspect
25
26
  import importlib
26
27
  import platform
@@ -39,7 +40,7 @@ from mindspore.communication.management import get_rank, GlobalComm
39
40
  from ._ms_kernel import determine_variable_usage
40
41
  from ._custom_grad import autodiff_bprop
41
42
  from ._pyfunc_registry import add_pyfunc
42
- from ._custom_ops_utils import ExtensionBuilder
43
+ from ._custom_ops_utils import ExtensionBuilder, CustomCodeGenerator, CustomInfoGenerator
43
44
 
44
45
  if platform.system() != "Windows":
45
46
  import fcntl
@@ -72,13 +73,18 @@ def _get_cache_path():
72
73
  """
73
74
  cache_path = os.getenv('MS_COMPILER_CACHE_PATH')
74
75
  if cache_path is None:
75
- cache_path = "./akg_kernel_meta/"
76
+ cache_path = "./custom_kernel_meta/"
76
77
  elif cache_path[-1] != "/":
77
78
  cache_path = cache_path + "/"
78
79
 
79
80
  if not os.path.exists(cache_path):
80
81
  os.makedirs(cache_path, exist_ok=True)
81
82
 
83
+ # for distributed case, we create folders separately to avoid conflict
84
+ if GlobalComm.INITED:
85
+ cache_path = os.path.join(cache_path, "rank_" + str(get_rank()), "")
86
+ os.makedirs(cache_path, exist_ok=True)
87
+
82
88
  return cache_path
83
89
 
84
90
 
@@ -93,10 +99,6 @@ def _compile_aot(file):
93
99
  str, the path to the compiled library.
94
100
  """
95
101
  cache_path = _get_cache_path()
96
- # for distributed case, we create folders separately to avoid conflict
97
- if GlobalComm.INITED:
98
- cache_path = os.path.join(cache_path, "rank_" + str(get_rank()), "")
99
- os.makedirs(cache_path, exist_ok=True)
100
102
 
101
103
  res_path = importlib.util.find_spec("mindspore").origin
102
104
  find_pos = res_path.find("__init__.py")
@@ -109,12 +111,19 @@ def _compile_aot(file):
109
111
  func_path = cache_path + file_name + ".so"
110
112
  include_file = "{} -I{}".format(include_file, file[:file.rindex('/')])
111
113
 
114
+ if context.get_context("device_target") == "Ascend":
115
+ ascend_cann_path = os.getenv("ASCEND_OPP_PATH").split('opp')[0]
116
+ ascend_include = os.path.join(ascend_cann_path, "include")
117
+ include_file = "{} -I{}".format(include_file, ascend_include)
118
+
119
+ include_file = include_file.split(" ")
112
120
  if func_path not in Custom.compiled_bin:
113
121
  Custom.compiled_bin.append(func_path)
114
122
 
115
123
  if file.endswith("cpp") or file.endswith("cc"):
116
124
  cmd = ["g++", "-std=c++17", "--shared", "-fPIC", "-D_GLIBCXX_USE_CXX11_ABI=0"]
117
- cmd += [include_file, "-o", func_path, file]
125
+ cmd += include_file
126
+ cmd += ["-o", func_path, file]
118
127
  elif file.endswith("cu"):
119
128
  cmd = ["nvcc"]
120
129
  cmd += ["--shared", "-Xcompiler", "-fPIC", "-O3", "-gencode", "arch=compute_70, code=sm_70"]
@@ -141,12 +150,13 @@ def _compile_aot(file):
141
150
  logger.warning("The current version of nvcc, V{}.{}.{}, might have unfixed issues with std string, "
142
151
  "which will lead to errors in aot custom op with attrs."
143
152
  "The version higher than V10.1.168 is recommended".format(v_major, v_mid, v_minor))
144
- cmd += [include_file, "-o", func_path, file]
153
+ cmd += include_file
154
+ cmd += ["-o", func_path, file]
145
155
  else:
146
156
  raise ValueError("The source file must be a cc/cpp/cu file, but get: {}".format(file))
147
157
 
148
158
  proc = subprocess.Popen(
149
- cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
159
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=False)
150
160
 
151
161
  (out, _) = proc.communicate(timeout=30)
152
162
 
@@ -224,14 +234,13 @@ class Custom(ops.PrimitiveWithInfer):
224
234
  <https://www.mindspore.cn/tutorials/en/master/custom_program/op_custom.html>`_ .
225
235
 
226
236
  .. warning::
227
- - This is an experimental API that is subject to change.
237
+ This is an experimental API that is subject to change.
228
238
 
229
239
  .. note::
230
240
  The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
231
241
 
232
242
  - "aot": supports ["GPU", "CPU", "Ascend"].
233
243
  - "pyfunc": supports ["CPU"].
234
- - "julia": supports ["CPU"].
235
244
 
236
245
  Args:
237
246
  func (Union[function, str]):
@@ -240,101 +249,84 @@ class Custom(ops.PrimitiveWithInfer):
240
249
  computation logic of a user defined operator.
241
250
 
242
251
  - str: If func is of str type, then str should be a path of file along with a function name.
243
- This could be used when func_type is "aot" or "julia".
244
-
245
- 1. for "aot":
246
-
247
- a) GPU/CPU platform.
248
- "aot" means ahead of time, in which case Custom directly launches user defined "xxx.so" file as an
249
- operator. Users need to compile a handwriting "xxx.cu"/"xxx.cc" file into "xxx.so" ahead of time,
250
- and offer the path of the file along with a function name.
251
-
252
- - "xxx.so" file generation:
253
-
254
- 1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"), use nvcc command to compile
255
- it.(ex. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu")
256
-
257
- 2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"), use g++/gcc command to
258
- compile it.(ex. "g++ --shared -fPIC -o add.so add.cc")
252
+ This could be used when func_type is "aot".
259
253
 
260
- - Define a "xxx.cc"/"xxx.cu" file:
254
+ for "aot":
261
255
 
262
- "aot" is a cross-platform identity. The functions defined in "xxx.cc" or "xxx.cu" share
263
- the same args. Typically, the function should be as:
256
+ a) GPU/CPU platform.
257
+ "aot" means ahead of time, in which case Custom directly launches user defined "xxx.so" file as an
258
+ operator. Users need to compile a handwriting "xxx.cu"/"xxx.cc" file into "xxx.so" ahead of time,
259
+ and offer the path of the file along with a function name.
264
260
 
265
- .. code-block::
261
+ - "xxx.so" file generation:
266
262
 
267
- int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
268
- void *stream, void *extra)
263
+ 1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"), use nvcc command to compile
264
+ it.(ex. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu")
269
265
 
270
- Parameters:
266
+ 2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"), use g++/gcc command to
267
+ compile it.(ex. "g++ --shared -fPIC -o add.so add.cc")
271
268
 
272
- - nparam(int): total number of inputs plus outputs; suppose the operator has 2 inputs and 3 outputs,
273
- then nparam=5
274
- - params(void \*\*): a pointer to the array of inputs and outputs' pointer; the pointer type of
275
- inputs and outputs is void \* ; suppose the operator has 2 inputs and 3 outputs, then the first
276
- input's pointer is params[0] and the second output's pointer is params[3]
277
- - ndims(int \*): a pointer to the array of inputs and outputs' dimension num; suppose params[i] is a
278
- 1024x1024 tensor and params[j] is a 77x83x4 tensor, then ndims[i]=2, ndims[j]=3.
279
- - shapes(int64_t \*\*): a pointer to the array of inputs and outputs' shapes(int64_t \*); the ith
280
- input's jth dimension's size is shapes[i][j](0<=j<ndims[i]); suppose params[i] is a 2x3 tensor and
281
- params[j] is a 3x3x4 tensor, then shapes[i][0]=2, shapes[j][2]=4.
282
- - dtypes(const char \*\*): a pointer to the array of inputs and outputs' types(const char \*);
283
- (ex. "float32", "float16", "float", "float64", "int", "int8", "int16", "int32", "int64", "uint",
284
- "uint8", "uint16", "uint32", "uint64", "bool")
285
- - stream(void \*): stream pointer, only used in cuda file
286
- - extra(void \*): used for further extension
269
+ - Define a "xxx.cc"/"xxx.cu" file:
287
270
 
288
- Return Value(int):
271
+ "aot" is a cross-platform identity. The functions defined in "xxx.cc" or "xxx.cu" share
272
+ the same args. Typically, the function should be as:
289
273
 
290
- - 0: MindSpore will continue to run if this aot kernel is successfully executed
291
- - others: MindSpore will raise exception and exit
274
+ .. code-block::
292
275
 
293
- Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/
276
+ int func(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
277
+ void *stream, void *extra)
294
278
 
295
- - Use it in Custom:
279
+ Parameters:
296
280
 
297
- .. code-block::
281
+ - nparam(int): total number of inputs plus outputs; suppose the operator has 2 inputs and 3 outputs,
282
+ then nparam=5
283
+ - params(void \*\*): a pointer to the array of inputs and outputs' pointer; the pointer type of
284
+ inputs and outputs is void \* ; suppose the operator has 2 inputs and 3 outputs, then the first
285
+ input's pointer is params[0] and the second output's pointer is params[3]
286
+ - ndims(int \*): a pointer to the array of inputs and outputs' dimension num; suppose params[i] is a
287
+ 1024x1024 tensor and params[j] is a 77x83x4 tensor, then ndims[i]=2, ndims[j]=3.
288
+ - shapes(int64_t \*\*): a pointer to the array of inputs and outputs' shapes(int64_t \*); the ith
289
+ input's jth dimension's size is shapes[i][j](0<=j<ndims[i]); suppose params[i] is a 2x3 tensor and
290
+ params[j] is a 3x3x4 tensor, then shapes[i][0]=2, shapes[j][2]=4.
291
+ - dtypes(const char \*\*): a pointer to the array of inputs and outputs' types(const char \*);
292
+ (ex. "float32", "float16", "float", "float64", "int", "int8", "int16", "int32", "int64", "uint",
293
+ "uint8", "uint16", "uint32", "uint64", "bool")
294
+ - stream(void \*): stream pointer, only used in cuda file
295
+ - extra(void \*): used for further extension
298
296
 
299
- Custom(func="{dir_path}/{file_name}:{func_name}",...)
300
- (ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32,
301
- "aot"))
297
+ Return Value(int):
302
298
 
303
- b) Ascend platform.
304
- Before using Custom operators on the Ascend platform, users must first develop custom operators
305
- based on Ascend C and compile them. The complete development and usage process can refer to the
306
- tutorial `AOT-Type Custom Operators(Ascend)
307
- <https://www.mindspore.cn/tutorials/en/master/custom_program/operation/op_custom_ascendc.html>`_.
308
- By passing the name of the operator through the input parameter `func`, there are two usage methods
309
- based on the implementation of the infer function:
299
+ - 0: MindSpore will continue to run if this aot kernel is successfully executed
300
+ - others: MindSpore will raise exception and exit
310
301
 
311
- - Python infer: If the operator's infer function is implemented in Python, that is, the infer shape
312
- function is passed through the `out_shape` parameter, and the infer type is passed throuht the
313
- `out_dtype`, then the `func` should be specified as the operator name, for example,
314
- `func="CustomName"`.
315
- - C++ infer: If the operator's infer function is implemented through C++, then pass the path of the
316
- infer function implementation file in `func` and separate the operator name with `:`,
317
- for example: `func="add_custom_infer.cc:AddCustom"` .
302
+ Examples: see details in tests/st/ops/graph_kernel/custom/aot_test_files/
318
303
 
319
- 2. for "julia":
304
+ - Use it in Custom:
320
305
 
321
- Currently "julia" supports CPU(linux only) platform.
322
- For julia use JIT compiler, and julia support c api to call julia code.
323
- The Custom can directly launches user defined "xxx.jl" file as an operator.
324
- Users need to write a "xxx.jl" file which include modules and functions,
325
- and offer the path of the file along with a module name and function name.
306
+ .. code-block::
326
307
 
327
- Examples: see details in tests/st/ops/graph_kernel/custom/julia_test_files/
308
+ Custom(func="{dir_path}/{file_name}:{func_name}",...)
309
+ (ex. Custom(func="./reorganize.so:CustomReorganize", out_shape=[1], out_dtype=mstype.float32,
310
+ "aot"))
328
311
 
329
- - Use it in Custom:
312
+ b) Ascend platform.
313
+ Before using Custom operators on the Ascend platform, users must first develop custom operators
314
+ based on Ascend C and compile them. The complete development and usage process can refer to the
315
+ tutorial `AOT-Type Custom Operators(Ascend)
316
+ <https://www.mindspore.cn/tutorials/en/master/custom_program/operation/op_custom_ascendc.html>`_.
317
+ By passing the name of the operator through the input parameter `func`, there are two usage methods
318
+ based on the implementation of the infer function:
330
319
 
331
- .. code-block::
320
+ - Python infer: If the operator's infer function is implemented in Python, that is, the infer shape
321
+ function is passed through the `out_shape` parameter, and the infer type is passed throuht the
322
+ `out_dtype`, then the `func` should be specified as the operator name, for example,
323
+ `func="CustomName"`.
324
+ - C++ infer: If the operator's infer function is implemented through C++, then pass the path of the
325
+ infer function implementation file in `func` and separate the operator name with `:`,
326
+ for example: `func="add_custom_infer.cc:AddCustom"` .
332
327
 
333
- Custom(func="{dir_path}/{file_name}:{module_name}:{func_name}",...)
334
- (ex. Custom(func="./add.jl:Add:add", out_shape=[1], out_dtype=mstype.float32, "julia"))
335
-
336
- out_shape (Union[function, list, tuple]): The output shape infer function or the value of output shape of
337
- `func`. Default: ``None`` .
328
+ out_shape (Union[function, list, tuple], optional): The output shape infer function or the value of output
329
+ shape of `func`. Default: ``None`` .
338
330
 
339
331
  If func has single output, then the value of output shape is a list or tuple of int.
340
332
 
@@ -344,8 +336,8 @@ class Custom(ops.PrimitiveWithInfer):
344
336
  The input can be None only when the func_type input is "hybrid". In this case, the automatic infer
345
337
  shape mechanic will be enabled.
346
338
 
347
- out_dtype (Union[function, :class:`mindspore.dtype`, tuple[:class:`mindspore.dtype`]]): The output data type
348
- infer function or the value of output data type of `func`. Default: ``None`` .
339
+ out_dtype (Union[function, :class:`mindspore.dtype`, tuple[:class:`mindspore.dtype`]], optional): The output
340
+ data type infer function or the value of output data type of `func`. Default: ``None`` .
349
341
 
350
342
  If func has single output, then the value of output shape is a `mindspore.dtype`.
351
343
 
@@ -355,13 +347,12 @@ class Custom(ops.PrimitiveWithInfer):
355
347
  The input can be None only when the func_type input is "hybrid". In this case, the automatic infer
356
348
  value mechanic will be enabled.
357
349
 
358
- func_type (str): The implementation type of `func`, should be one of
359
-
360
- [ ``"aot"`` , ``"pyfunc"`` , ``"julia"`` ].
350
+ func_type (str, optional): The implementation type of `func`, should be one of
351
+ [ ``"aot"`` , ``"pyfunc"``]. Default: ``"pyfunc"``.
361
352
 
362
- bprop (function): The back propagation function of `func`. Default: ``None`` .
363
- reg_info (Union[str, dict, list, tuple]): Represents the registration information(reg info) of `func` with
364
- json format of type str or dict. The reg info specifies supported data types and formats of inputs and
353
+ bprop (function, optional): The back propagation function of `func`. Default: ``None`` .
354
+ reg_info (Union[str, dict, list, tuple], optional): Represents the registration information(reg info) of `func`
355
+ with json format of type str or dict. The reg info specifies supported data types and formats of inputs and
365
356
  outputs, attributes and target of `func`. Default: ``None`` .
366
357
 
367
358
  If reg info is a list or tuple, then each item should be with json format of type str or dict, which
@@ -419,11 +410,11 @@ class Custom(ops.PrimitiveWithInfer):
419
410
  op_path_in_cache = [] # Save paths for op functions created in the cached.
420
411
  custom_aot_warning = True # Flag to enable warnings about custom aot path white list
421
412
 
422
- def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
413
+ def __init__(self, func, out_shape=None, out_dtype=None, func_type="pyfunc", bprop=None, reg_info=None):
423
414
  super().__init__("Custom")
424
415
 
425
416
  self.supported_targets = [ASCEND, GPU, CPU]
426
- self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc", "julia"]
417
+ self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc"]
427
418
  self.log_prefix = "For '{}', 'func_type': {}, 'func': {}".format(self.name, func_type, func)
428
419
  self.func = func
429
420
  self.func_type = func_type
@@ -435,10 +426,13 @@ class Custom(ops.PrimitiveWithInfer):
435
426
  self._is_ms_kernel = False
436
427
  self.out_shape = out_shape
437
428
  self.out_dtype = out_dtype
429
+ self.reg_info = reg_info
430
+ self.is_ascend_c = (context.get_context("device_target") == "Ascend" and self.func_type == "aot")
438
431
 
439
432
  self._check_platform()
440
433
  self._check_func()
441
- self._update_func_info(reg_info)
434
+ self._generate_reg_info()
435
+ self._update_func_info(self.reg_info)
442
436
  self.add_prim_attr("func_name", self.func_name)
443
437
  self.add_prim_attr("uniq_name", self.uniq_name)
444
438
  if self.func_type == HYBRID_TYPE:
@@ -450,25 +444,22 @@ class Custom(ops.PrimitiveWithInfer):
450
444
  add_pyfunc(func_id, self.func)
451
445
  self.add_prim_attr("fn_id", func_id)
452
446
 
453
- if self.out_shape is None and self.func_type == "aot":
454
- self.add_prim_attr("cpp_infer_shape", True)
455
- if self.out_dtype is None and self.func_type == "aot":
456
- self.add_prim_attr("cpp_infer_type", True)
457
- self.multi_output = (reg_info is not None and (len(reg_info.get("outputs", [])) > 1))
458
- self.add_prim_attr("multi_output", self.multi_output)
447
+ self._set_infer_flag()
448
+ self._set_multi_output_flag()
459
449
 
460
450
  self.bprop = bprop
461
451
  self.fake_output = False
462
452
  self.single_scalar_output = False
463
- if not self.out_dtype and not self.func_type == "pyfunc":
464
- self.fake_output = True
465
- elif not self.out_shape and self.func_type == "pyfunc":
466
- self.single_scalar_output = True
467
- self.add_prim_attr("fake_output", self.fake_output)
468
- self.add_prim_attr("single_scalar_output", self.single_scalar_output)
453
+ if self.func_type == "pyfunc":
454
+ if not self.out_dtype:
455
+ self.fake_output = True
456
+ elif not self.out_shape:
457
+ self.single_scalar_output = True
458
+ self.add_prim_attr("fake_output", self.fake_output)
459
+ self.add_prim_attr("single_scalar_output", self.single_scalar_output)
469
460
 
470
461
  # Register info
471
- self._register_info(reg_info)
462
+ self._register_info(self.reg_info)
472
463
 
473
464
  if func_type == "akg":
474
465
  self._set_akg_kernel_type()
@@ -479,11 +470,23 @@ class Custom(ops.PrimitiveWithInfer):
479
470
  self.add_prim_attr("func_type", self.func_type)
480
471
  self._update_attr()
481
472
 
482
- self.enable_pyboost = (context.get_context("device_target") == "Ascend" and self.func_type == "aot")
483
- if self.enable_pyboost:
473
+ if self.is_ascend_c:
484
474
  self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
485
475
  for key, value in super().get_attr_dict().items():
486
476
  self.custom_pyboost.add_prim_attr(key, value)
477
+ self._generate_get_workspace_size_func()
478
+
479
+ def _set_infer_flag(self):
480
+ """set cpp infer attr"""
481
+ if self.out_shape is None and self.func_type == "aot":
482
+ self.add_prim_attr("cpp_infer_shape", True)
483
+ if self.out_dtype is None and self.func_type == "aot":
484
+ self.add_prim_attr("cpp_infer_type", True)
485
+
486
+ def _set_multi_output_flag(self):
487
+ outputs = self.reg_info.get("outputs", []) if self.reg_info else []
488
+ self.multi_output = len(outputs) > 1 or (len(outputs) == 1 and outputs[0].get("paramType") == "dynamic")
489
+ self.add_prim_attr("multi_output", self.multi_output)
487
490
 
488
491
  def __infer__(self, *args):
489
492
  if callable(self.out_shape):
@@ -563,22 +566,6 @@ class Custom(ops.PrimitiveWithInfer):
563
566
  self.func_type = HYBRID_TYPE
564
567
  self._hybrid_func_analyser()
565
568
 
566
- def _check_julia_func(self):
567
- """Check the validity of julia func"""
568
- if not isinstance(self.func, str):
569
- raise TypeError("{}, 'func' must be of type str, but got {}".format(self.log_prefix, type(self.func)))
570
- if self.func.count(':') != 2:
571
- raise ValueError("{}, the format of 'func' must be file:module:func".format(self.log_prefix))
572
- source_file, module, func = self.func.split(':')
573
- with open(source_file, 'r') as f:
574
- jl = f.read()
575
- if 'module ' + module not in jl:
576
- raise Exception("{}, module {} is not found in source file {}!"
577
- .format(self.log_prefix, module, source_file))
578
- if 'function ' + func not in jl:
579
- raise Exception("{}, function {} is not found in source file {}!"
580
- .format(self.log_prefix, func, source_file))
581
-
582
569
  def _check_aot_func(self):
583
570
  """Check the source code and bin lib for aot type custom op"""
584
571
  if not isinstance(self.func, str):
@@ -621,8 +608,6 @@ class Custom(ops.PrimitiveWithInfer):
621
608
  if self.func_type == "aot":
622
609
  self._check_aot_func()
623
610
 
624
- elif self.func_type == "julia":
625
- self._check_julia_func()
626
611
  elif self.func_type == HYBRID_TYPE:
627
612
  if not hasattr(self.func, MS_KERNEL_FLAG):
628
613
  raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
@@ -776,6 +761,26 @@ class Custom(ops.PrimitiveWithInfer):
776
761
  if isinstance(item, dict) and item.get("value") is not None:
777
762
  self.add_prim_attr(item[KEY_NAME], item["value"])
778
763
 
764
+ def _convert_attr_to_input(self, ori_reg_info):
765
+ """convert attr to input"""
766
+ if not self.is_ascend_c or not ori_reg_info.get("attr"):
767
+ return ori_reg_info
768
+
769
+ reg_info = copy.deepcopy(ori_reg_info)
770
+ start_index = len(reg_info.get("inputs", []))
771
+ for i, attr_item in enumerate(reg_info.get("attr", [])):
772
+ new_input = {
773
+ 'index': start_index + i,
774
+ 'name': attr_item['name'],
775
+ 'paramType': attr_item['paramType']}
776
+ reg_info['inputs'].append(new_input)
777
+ for dtype_format_item in reg_info.get("dtype_format", []):
778
+ new_dtype_format_item = list(dtype_format_item)
779
+ new_dtype_format_item.insert(start_index + i, DataType.None_None)
780
+ reg_info['dtype_format'][reg_info['dtype_format'].index(dtype_format_item)] = new_dtype_format_item
781
+ reg_info['attr'] = []
782
+ return reg_info
783
+
779
784
  def _register_info(self, info):
780
785
  """Register reg_info."""
781
786
  reg_info = info
@@ -806,14 +811,15 @@ class Custom(ops.PrimitiveWithInfer):
806
811
  continue
807
812
  # Register
808
813
  reg_info = self._reformat_reg_info(reg_info, target)
809
- reg_info_str = json.dumps(reg_info)
814
+ new_reg_info = self._convert_attr_to_input(reg_info)
815
+ reg_info_str = json.dumps(new_reg_info)
810
816
  op_lib = Oplib()
811
817
  if not op_lib.reg_op(reg_info_str, self.imply_path):
812
818
  raise ValueError("{}, the registration information is registered failed. Use 'CustomRegOp' to "
813
819
  "generate the registration information, then pass it to 'reg_info' or use "
814
820
  "'custom_info_register' to bind it to 'func' if 'func' is a function."
815
821
  .format(self.log_prefix))
816
- self._save_attr(reg_info)
822
+ self._save_attr(new_reg_info)
817
823
  self._save_register_status(target)
818
824
 
819
825
  def _get_expanded_list(self, data):
@@ -918,7 +924,7 @@ class Custom(ops.PrimitiveWithInfer):
918
924
  return reg_info[IMPLY_TYPE]
919
925
  # Infer imply_type from func_type
920
926
  func_type_to_imply_type = {"hybrid": AKG, "akg": AKG, "tbe": TBE, "aicpu": "AiCPU", "pyfunc": target,
921
- "julia": target, "aot": "BiSheng" if target == ASCEND else target}
927
+ "aot": "BiSheng" if target == ASCEND else target}
922
928
  return func_type_to_imply_type.get(self.func_type, AKG)
923
929
 
924
930
  def _save_attr(self, reg_info):
@@ -990,17 +996,6 @@ class Custom(ops.PrimitiveWithInfer):
990
996
  self.set_device(GPU)
991
997
  elif registered_targets == [CPU]:
992
998
  self.set_device(CPU)
993
- elif self.func_type == "julia":
994
- self.set_device(CPU)
995
- device_target = context.get_context('device_target')
996
- if device_target == CPU:
997
- pass
998
- elif device_target == GPU and registered_targets and registered_targets == [CPU]:
999
- logger.warning("{}, only supports CPU platform, but got registered target {}. "
1000
- "We will run it on CPU".format(self.log_prefix, registered_targets))
1001
- else:
1002
- raise ValueError("{}, only supports CPU platform, but got target {}."
1003
- .format(self.log_prefix, device_target))
1004
999
 
1005
1000
  def _update_attr(self):
1006
1001
  """Add input_names, attr_names, primitive_target to primitive's attr."""
@@ -1080,24 +1075,107 @@ class Custom(ops.PrimitiveWithInfer):
1080
1075
  if isinstance(arg_dtype, mstype.TensorType):
1081
1076
  arg_dtype = arg_dtype.element_type()
1082
1077
  fake_arg = np.zeros(arg["shape"]).astype(
1083
- mstype.dtype_to_nptype(arg_dtype))
1078
+ mstype._dtype_to_nptype(arg_dtype)) # pylint:disable=protected-access
1084
1079
  fake_input.append(fake_arg)
1085
1080
 
1086
1081
  fake_output = self.func(*fake_input)
1087
1082
 
1088
1083
  if hasattr(fake_output, 'shape'):
1089
1084
  infer_shape = fake_output.shape
1090
- infer_dtype = mstype.TensorType(mstype.pytype_to_dtype(fake_output.dtype))
1085
+ # pylint:disable=protected-access
1086
+ infer_dtype = mstype.TensorType(mstype._pytype_to_dtype(fake_output.dtype))
1091
1087
  else:
1092
1088
  infer_shape = (1,)
1093
- infer_dtype = mstype.pytype_to_dtype(fake_output.dtype)
1089
+ infer_dtype = mstype._pytype_to_dtype(fake_output.dtype) # pylint:disable=protected-access
1094
1090
 
1095
1091
  infer_value = Tensor(fake_output) if enable_infer_value else None
1096
1092
 
1097
1093
  return infer_shape, infer_dtype, infer_value
1098
1094
 
1095
+ def _generate_reg_info(self):
1096
+ if not self.is_ascend_c:
1097
+ return
1098
+ if self.reg_info is None:
1099
+ func_name, _ = self._split_func()
1100
+ if func_name.startswith("aclnn"):
1101
+ func_name = func_name[len("aclnn"):]
1102
+ reg_info_generator = CustomInfoGenerator(func_name)
1103
+ self.reg_info = reg_info_generator.generate_custom_reg_op()
1104
+
1105
+ def _split_func(self):
1106
+ func_list = self.func.split(":")
1107
+ func_path = ""
1108
+ if len(func_list) == 2:
1109
+ func_path = func_list[0]
1110
+ func_name = func_list[1]
1111
+ else:
1112
+ func_name = self.func
1113
+ return func_name, func_path
1114
+
1115
+ def _generate_get_worspace_size_func_by_types(self, aclnn_api_types):
1116
+ """generate custom GetWorkSpaceSize func by aclnn api types"""
1117
+ if not self.is_ascend_c:
1118
+ return
1119
+
1120
+ input_output_types = []
1121
+ if isinstance(aclnn_api_types, str):
1122
+ params = re.split(r',\s*', aclnn_api_types)
1123
+ for param in params:
1124
+ param = param.replace('const ', '')
1125
+ type_part = re.search(r'^\s*(\w+\s*\*+|\w+)', param).group(1)
1126
+ type_part = type_part.replace(' ', '')
1127
+ input_output_types.append(type_part)
1128
+ elif isinstance(aclnn_api_types, list):
1129
+ input_output_types = aclnn_api_types
1130
+ else:
1131
+ raise RuntimeError(f"Unsupported type: {type(aclnn_api_types)}, support type is list or string.")
1132
+
1133
+ func_name, _ = self._split_func()
1134
+ file_path = os.path.join(_get_cache_path(), func_name, func_name + "_callback.cc")
1135
+
1136
+ file_path = os.path.abspath(file_path)
1137
+ dir_path = os.path.dirname(file_path)
1138
+ os.makedirs(dir_path, exist_ok=True)
1139
+
1140
+ custom_builder = CustomCodeGenerator()
1141
+ callback_func = custom_builder.generate_callback_by_types(func_name, self.reg_info, input_output_types)
1142
+
1143
+ with open(file_path, 'w') as f:
1144
+ f.write(callback_func)
1145
+
1146
+ custom_callback_func_path = _compile_aot(file_path)
1147
+ custom_callback_func = custom_callback_func_path + ":" + func_name
1148
+ self.add_prim_attr("custom_callback_func", custom_callback_func)
1149
+ self.add_prim_attr("custom_inputs_type", input_output_types[:-2])
1150
+
1151
+ def _generate_get_workspace_size_func(self):
1152
+ """generate custom GetWorkSpaceSize func"""
1153
+ if not self.is_ascend_c:
1154
+ return
1155
+ func_name, _ = self._split_func()
1156
+ file_path = os.path.join(_get_cache_path(), func_name, func_name + "_callback.cc")
1157
+
1158
+ file_path = os.path.abspath(file_path)
1159
+ dir_path = os.path.dirname(file_path)
1160
+ os.makedirs(dir_path, exist_ok=True)
1161
+
1162
+ custom_info_generator = CustomInfoGenerator(func_name)
1163
+ api_types = custom_info_generator.get_aclnn_api_types()
1164
+ custom_builder = CustomCodeGenerator()
1165
+ if api_types == []:
1166
+ api_types = custom_builder.get_api_types_by_reg_info(self.reg_info)
1167
+
1168
+ callback_func = custom_builder.generate_callback_by_types(func_name, self.reg_info, api_types)
1169
+ with open(file_path, 'w') as f:
1170
+ f.write(callback_func)
1171
+
1172
+ custom_callback_func_path = _compile_aot(file_path)
1173
+ custom_callback_func = custom_callback_func_path + ":" + func_name
1174
+ self.add_prim_attr("custom_callback_func", custom_callback_func)
1175
+ self.add_prim_attr("custom_inputs_type", api_types[:-2])
1176
+
1099
1177
  def __call__(self, *args):
1100
- if self.enable_pyboost:
1178
+ if self.is_ascend_c:
1101
1179
  res = pyboost_custom_ext(self.custom_pyboost, [args])
1102
1180
  return res if self.multi_output else res[0]
1103
1181
  should_elim, output = self.check_elim(*args)
@@ -1130,6 +1208,15 @@ class CustomOpBuilder:
1130
1208
  ldflags (str, optional): Extra linker flags to be used during linking. Default: ``None``.
1131
1209
  kwargs (dict, optional): Additional keyword arguments for future extensions or specific custom requirements.
1132
1210
 
1211
+ - build_dir (str, optional): The directory used to generate the operator build files.
1212
+ If this argument is set, the provided path will be used directly.
1213
+ If not set, a subdirectory named after the operator's name will be created under the path specified by
1214
+ the environment variable `MS_COMPILER_CACHE_PATH` (defaulting to "./kernel_meta"), and the files will
1215
+ be placed in this subdirectory. Default: ``None``.
1216
+
1217
+ - enable_atb (bool, optional): Whether to call ATB (Ascend Transformer Boost) operator. If set to ``True``,
1218
+ the `backend` must be ``Ascend`` or left empty. Default: ``False``.
1219
+
1133
1220
  .. note::
1134
1221
  - If the `backend` argument is provided, additional default flags will be automatically added to
1135
1222
  the compilation and linking steps to support the operator's target backend. The default options
@@ -1149,9 +1236,7 @@ class CustomOpBuilder:
1149
1236
  ... )
1150
1237
  >>> my_ops = builder.load()
1151
1238
  """
1152
- _mindspore_path = None
1153
1239
  _loaded_ops = {}
1154
- _ms_code_base = None
1155
1240
 
1156
1241
  def __init__(self, name, sources, backend=None, include_paths=None, cflags=None, ldflags=None, **kwargs):
1157
1242
  self.name = name
@@ -1160,11 +1245,25 @@ class CustomOpBuilder:
1160
1245
  self.include_paths = include_paths
1161
1246
  self.cflags = cflags
1162
1247
  self.ldflags = ldflags
1163
- if CustomOpBuilder._mindspore_path is None:
1164
- CustomOpBuilder._mindspore_path = os.path.dirname(os.path.abspath(ms.__file__))
1165
- CustomOpBuilder._ms_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include")
1248
+ self.build_dir = kwargs.get("build_dir")
1249
+ self.enable_atb = kwargs.get("enable_atb", False)
1250
+ self.debug_mode = kwargs.get("debug_mode", False)
1251
+ self._ms_path = os.path.dirname(os.path.abspath(ms.__file__))
1252
+ if self.enable_atb:
1253
+ if backend is not None and backend != "Ascend":
1254
+ raise ValueError("For 'CustomOpBuilder', when 'enable_atb' is set to True, the 'backend' must be "
1255
+ f"'Ascend' (or left implicit), but got '{backend}'")
1256
+ self.backend = "Ascend"
1166
1257
  if self.backend == "Ascend":
1167
- self.ascend_cann_path = os.getenv("ASCEND_OPP_PATH").split('opp')[0]
1258
+ ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
1259
+ if not ascend_opp_path:
1260
+ raise ValueError("Environment variable 'ASCEND_OPP_PATH' must be set for Ascend backend.")
1261
+ self.ascend_cann_path = ascend_opp_path.split('opp')[0]
1262
+
1263
+ if self.enable_atb:
1264
+ self.atb_home_path = os.getenv("ATB_HOME_PATH")
1265
+ if not self.atb_home_path:
1266
+ raise ValueError("Environment variable 'ATB_HOME_PATH' must be set when 'enable_atb' is True.")
1168
1267
 
1169
1268
  def get_sources(self):
1170
1269
  """
@@ -1183,29 +1282,31 @@ class CustomOpBuilder:
1183
1282
  list[str], A list of include paths.
1184
1283
  """
1185
1284
  include_list = self.include_paths if self.include_paths is not None else []
1186
- include_list.append(CustomOpBuilder._mindspore_path)
1187
- include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include"))
1188
- include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party"))
1189
- include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party/robin_hood_hashing"))
1190
- include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party/securec/include"))
1285
+ include_list.append(self._ms_path)
1286
+ include_list.append(os.path.join(self._ms_path, "include"))
1287
+ include_list.append(os.path.join(self._ms_path, "include", "third_party"))
1288
+ include_list.append(os.path.join(self._ms_path, "include", "third_party", "robin_hood_hashing"))
1289
+ include_list.append(os.path.join(self._ms_path, "include", "third_party", "securec", "include"))
1191
1290
 
1192
1291
  if self.backend == "Ascend":
1193
1292
  include_list.append(os.path.join(self.ascend_cann_path, "include"))
1293
+ if self.enable_atb:
1294
+ include_list.append(os.path.join(self.atb_home_path, "include"))
1194
1295
  include_list += self._get_ms_inner_includes()
1195
1296
  return include_list
1196
1297
 
1197
1298
  def _get_ms_inner_includes(self):
1198
1299
  """include paths for inner module interface."""
1199
- ms_inner_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include", "mindspore")
1300
+ ms_inner_path = os.path.join(self._ms_path, "include", "mindspore")
1200
1301
  include_list = []
1201
- include_list.append(ms_inner_code_base + "/core/include")
1202
- include_list.append(ms_inner_code_base + "/core/mindrt/include")
1203
- include_list.append(ms_inner_code_base + "/core/mindrt")
1204
- include_list.append(ms_inner_code_base + "/ops")
1205
- include_list.append(ms_inner_code_base + "/ops/kernel/include")
1206
- include_list.append(ms_inner_code_base + "/ccsrc")
1207
- include_list.append(ms_inner_code_base + "/ccsrc/include")
1208
- include_list.append(ms_inner_code_base + "/ccsrc/minddata/mindrecord/include")
1302
+ include_list.append(os.path.join(ms_inner_path, "core", "include"))
1303
+ include_list.append(os.path.join(ms_inner_path, "core", "mindrt", "include"))
1304
+ include_list.append(os.path.join(ms_inner_path, "core", "mindrt"))
1305
+ include_list.append(os.path.join(ms_inner_path, "ops"))
1306
+ include_list.append(os.path.join(ms_inner_path, "ops", "kernel", "include"))
1307
+ include_list.append(os.path.join(ms_inner_path, "ccsrc"))
1308
+ include_list.append(os.path.join(ms_inner_path, "ccsrc", "include"))
1309
+ include_list.append(os.path.join(ms_inner_path, "ccsrc", "minddata", "mindrecord", "include"))
1209
1310
  return include_list
1210
1311
 
1211
1312
  def get_cflags(self):
@@ -1215,10 +1316,14 @@ class CustomOpBuilder:
1215
1316
  Returns:
1216
1317
  list[str], A list of C++ compiler flags.
1217
1318
  """
1218
- flags = ['-fstack-protector-all', '-fPIC', '-pie']
1219
- flags += ['-DENABLE_FAST_HASH_TABLE=1']
1319
+ flags = [f'-DMS_EXTENSION_NAME={self.name}', '-D_GLIBCXX_USE_CXX11_ABI=0', '-DENABLE_FAST_HASH_TABLE=1']
1320
+ flags += ['-std=c++17', '-fstack-protector-all', '-fPIC', '-pie']
1321
+ if self.debug_mode:
1322
+ flags.append('-g')
1220
1323
  if self.backend == "Ascend":
1221
1324
  flags.append('-DCUSTOM_ASCEND_OP')
1325
+ if self.enable_atb:
1326
+ flags.append('-DCUSTOM_ENABLE_ATB')
1222
1327
  if self.cflags is not None:
1223
1328
  flags.append(self.cflags)
1224
1329
  return flags
@@ -1230,18 +1335,27 @@ class CustomOpBuilder:
1230
1335
  Returns:
1231
1336
  list[str], A list of linker flags.
1232
1337
  """
1233
- flags = ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath', '-s']
1338
+ flags = ['-shared']
1339
+ flags += ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath']
1340
+ if not self.debug_mode:
1341
+ flags.append('-s') # strip
1234
1342
  flags += [
1235
- '-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib')),
1343
+ f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib'))}",
1236
1344
  '-lmindspore_core',
1237
1345
  '-lmindspore_ms_backend',
1238
- '-lmindspore_pynative'
1346
+ '-lmindspore_pynative',
1347
+ '-lmindspore_extension'
1239
1348
  ]
1240
1349
  if self.backend == "Ascend":
1241
- flags.append('-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib/plugin')))
1242
- flags.append('-L' + os.path.abspath(os.path.join(self.ascend_cann_path, "lib64")))
1350
+ flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin'))}")
1351
+ flags.append(f"-L{os.path.abspath(os.path.join(self.ascend_cann_path, 'lib64'))}")
1243
1352
  flags.append('-lascendcl')
1244
1353
  flags.append('-l:libmindspore_ascend.so.2')
1354
+ if self.enable_atb:
1355
+ flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin', 'ascend'))}")
1356
+ flags.append('-lmindspore_extension_ascend_atb')
1357
+ flags.append(f"-L{os.path.abspath(os.path.join(self.atb_home_path, 'lib'))}")
1358
+ flags.append('-latb')
1245
1359
  if self.ldflags is not None:
1246
1360
  flags.append(self.ldflags)
1247
1361
  return flags
@@ -1256,7 +1370,7 @@ class CustomOpBuilder:
1256
1370
  Returns:
1257
1371
  str, The path to the compiled module.
1258
1372
  """
1259
- return ExtensionBuilder().build(
1373
+ return ExtensionBuilder(self._get_build_directory()).build(
1260
1374
  module_name=self.name,
1261
1375
  sources=self.get_sources(),
1262
1376
  extra_include_paths=self.get_include_paths(),
@@ -1283,3 +1397,15 @@ class CustomOpBuilder:
1283
1397
  module = importlib.util.module_from_spec(spec)
1284
1398
  spec.loader.exec_module(module)
1285
1399
  return module
1400
+
1401
+ def _get_build_directory(self):
1402
+ """Get build directory."""
1403
+ if self.build_dir is None:
1404
+ build_root = os.path.realpath(os.getenv('MS_COMPILER_CACHE_PATH', "./kernel_meta"))
1405
+ self.build_dir = os.path.join(build_root, self.name)
1406
+ else:
1407
+ self.build_dir = os.path.realpath(self.build_dir)
1408
+ logger.info(f'Build {self.name} in directory {self.build_dir}')
1409
+ if not os.path.exists(self.build_dir):
1410
+ os.makedirs(self.build_dir, exist_ok=True)
1411
+ return self.build_dir