mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,6 @@ import numpy as np
29
29
  import mindspore as ms
30
30
  from mindspore._c_expression import Oplib, typing
31
31
  from mindspore._c_expression import pyboost_custom_ext
32
- from mindspore.common._stub_tensor import _convert_stub
33
32
  from mindspore import context
34
33
  from mindspore.common import Tensor
35
34
  from mindspore.common import dtype as mstype
@@ -40,6 +39,7 @@ from mindspore.communication.management import get_rank, GlobalComm
40
39
  from ._ms_kernel import determine_variable_usage
41
40
  from ._custom_grad import autodiff_bprop
42
41
  from ._pyfunc_registry import add_pyfunc
42
+ from ._custom_ops_utils import ExtensionBuilder
43
43
 
44
44
  if platform.system() != "Windows":
45
45
  import fcntl
@@ -183,10 +183,16 @@ class _CustomExt(ops.PrimitiveWithInfer):
183
183
 
184
184
  infer_value = None
185
185
  if infer_shape is None:
186
- logger.warning("'out_shape' is None. Add a placeholder instead. "
187
- "A CPP version of infer shape function is required "
188
- "in this case.")
186
+ logger.debug("'out_shape' is None. Add a placeholder instead. "
187
+ "A CPP version of infer shape function is required "
188
+ "in this case.")
189
189
  infer_shape = (1,)
190
+ if infer_dtype is None:
191
+ logger.debug("'out_dtype' is None. Add a placeholder instead. "
192
+ "A CPP version of infer type function is required "
193
+ "in this case.")
194
+ infer_dtype = ms.float16
195
+
190
196
  # after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
191
197
  if not isinstance(infer_shape, (tuple, list)):
192
198
  raise TypeError("'out_shape' must be one of [tuple, list, function], but got {}".format(type(infer_shape)))
@@ -215,7 +221,7 @@ class Custom(ops.PrimitiveWithInfer):
215
221
  function if needed. Then these `Custom` objects can be directly used in neural networks.
216
222
  Detailed description and introduction of user-defined operators, including correct writing of parameters,
217
223
  please refer to `Custom Operators Tutorial
218
- <https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_ .
224
+ <https://www.mindspore.cn/tutorials/en/master/custom_program/op_custom.html>`_ .
219
225
 
220
226
  .. warning::
221
227
  - This is an experimental API that is subject to change.
@@ -223,8 +229,6 @@ class Custom(ops.PrimitiveWithInfer):
223
229
  .. note::
224
230
  The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
225
231
 
226
- - "hybrid": supports ["GPU", "CPU"].
227
- - "akg": supports ["GPU", "CPU"].
228
232
  - "aot": supports ["GPU", "CPU", "Ascend"].
229
233
  - "pyfunc": supports ["CPU"].
230
234
  - "julia": supports ["CPU"].
@@ -233,11 +237,7 @@ class Custom(ops.PrimitiveWithInfer):
233
237
  func (Union[function, str]):
234
238
 
235
239
  - function: If func is of function type, then func should be a Python function which describes the
236
- computation logic of a user defined operator. The function can be one of the following:
237
-
238
- 1. A AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
239
- 2. A pure python function
240
- 3. An kernel decorated function written by the Hybrid DSL.
240
+ computation logic of a user defined operator.
241
241
 
242
242
  - str: If func is of str type, then str should be a path of file along with a function name.
243
243
  This could be used when func_type is "aot" or "julia".
@@ -251,13 +251,11 @@ class Custom(ops.PrimitiveWithInfer):
251
251
 
252
252
  - "xxx.so" file generation:
253
253
 
254
- 1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"),
255
- use nvcc command to compile
256
- it.(ex. :code:`nvcc --shared -Xcompiler -fPIC -o add.so add.cu`)
254
+ 1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"), use nvcc command to compile
255
+ it.(ex. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu")
257
256
 
258
- 2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"),
259
- use g++/gcc command to
260
- compile it.(ex. :code:`g++ --shared -fPIC -o add.so add.cc`)
257
+ 2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"), use g++/gcc command to
258
+ compile it.(ex. "g++ --shared -fPIC -o add.so add.cc")
261
259
 
262
260
  - Define a "xxx.cc"/"xxx.cu" file:
263
261
 
@@ -305,14 +303,17 @@ class Custom(ops.PrimitiveWithInfer):
305
303
  b) Ascend platform.
306
304
  Before using Custom operators on the Ascend platform, users must first develop custom operators
307
305
  based on Ascend C and compile them. The complete development and usage process can refer to the
308
- tutorial `AOT-Type Custom Operators(Ascend) <https://www.mindspore.cn/docs/en/master/model_train/custom_program/operation/op_custom_ascendc.html>`_.
306
+ tutorial `AOT-Type Custom Operators(Ascend)
307
+ <https://www.mindspore.cn/tutorials/en/master/custom_program/operation/op_custom_ascendc.html>`_.
309
308
  By passing the name of the operator through the input parameter `func`, there are two usage methods
310
- based on the implementation of the infer shape function:
311
-
312
- - Python infer: If the operator's infer shape is implemented in Python, that is, the infer shape
313
- function is passed through the `out_shape` parameter, specify `func="CustomName"` .
314
- - C++ infer: If the operator's infer shape is implemented through C++, then pass the path of the
315
- infer shape implementation file in `func` and separate the operator name with `:`,
309
+ based on the implementation of the infer function:
310
+
311
+ - Python infer: If the operator's infer function is implemented in Python, that is, the infer shape
312
+ function is passed through the `out_shape` parameter, and the infer type is passed throuht the
313
+ `out_dtype`, then the `func` should be specified as the operator name, for example,
314
+ `func="CustomName"`.
315
+ - C++ infer: If the operator's infer function is implemented through C++, then pass the path of the
316
+ infer function implementation file in `func` and separate the operator name with `:`,
316
317
  for example: `func="add_custom_infer.cc:AddCustom"` .
317
318
 
318
319
  2. for "julia":
@@ -356,7 +357,7 @@ class Custom(ops.PrimitiveWithInfer):
356
357
 
357
358
  func_type (str): The implementation type of `func`, should be one of
358
359
 
359
- [ ``"hybrid"`` , ``"akg"`` , ``"aot"`` , ``"pyfunc"`` , ``"julia"`` ].
360
+ [ ``"aot"`` , ``"pyfunc"`` , ``"julia"`` ].
360
361
 
361
362
  bprop (function): The back propagation function of `func`. Default: ``None`` .
362
363
  reg_info (Union[str, dict, list, tuple]): Represents the registration information(reg info) of `func` with
@@ -400,50 +401,14 @@ class Custom(ops.PrimitiveWithInfer):
400
401
  >>> input_x = Tensor(np.ones([16, 16]).astype(np.float32))
401
402
  >>> input_y = Tensor(np.ones([16, 16]).astype(np.float32))
402
403
  >>>
403
- >>> # Example, func_type = "hybrid"
404
- >>> # This is the default func_type in Custom,
405
- >>> # and both out_shape and out_dtype can be None(default value).
406
- >>> # In this case, the input func must be a function written in the Hybrid DSL
407
- >>> # and decorated by @kernel.
408
- >>> @kernel
409
- ... def add_script(a, b):
410
- ... c = output_tensor(a.shape, a.dtype)
411
- ... for i0 in range(a.shape[0]):
412
- ... for i1 in range(a.shape[1]):
413
- ... c[i0, i1] = a[i0, i1] + b[i0, i1]
414
- ... return c
415
- >>>
416
- >>> test_op_hybrid = ops.Custom(add_script)
417
- >>> output = test_op_hybrid(input_x, input_y)
418
- >>> # the result will be a 16 * 16 tensor with all elements 2
419
- >>> print(output.shape)
420
- (16, 16)
421
- >>> # Example, func_type = "aot"
422
- >>> def test_aot(x, y, out_shapes, out_types):
423
- ... program = ops.Custom("./reorganize.so:CustomReorganize", out_shapes, out_types, "aot")
424
- ... out = program(x, y)
425
- ... return out
426
- >>>
427
404
  >>> # Example, func_type = "pyfunc"
428
- >>> def func_multi_output(x1, x2):
429
- ... return (x1 + x2), (x1 - x2)
405
+ >>> def func_pyfunc(x1, x2):
406
+ ... return x1 + x2
430
407
  >>>
431
- >>> test_pyfunc = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")
408
+ >>> test_pyfunc = ops.Custom(func_pyfunc, lambda x, _: x, lambda x, _: x, "pyfunc")
432
409
  >>> output = test_pyfunc(input_x, input_y)
433
- >>>
434
- >>> # Example, func_type = "julia"
435
- >>> # julia code:
436
- >>> # add.jl
437
- >>> # module Add
438
- >>> # function add(x, y, z)
439
- >>> # z .= x + y
440
- >>> # return z
441
- >>> # end
442
- >>> # end
443
- >>> def test_julia(x, y, out_shapes, out_types):
444
- ... program = ops.Custom("./add.jl:Add:add", out_shapes, out_types, "julia")
445
- ... out = program(x, y)
446
- ... return out
410
+ >>> print(output.shape)
411
+ (16, 16)
447
412
  """
448
413
 
449
414
  registered_func = {}
@@ -469,6 +434,7 @@ class Custom(ops.PrimitiveWithInfer):
469
434
  self._func_compile_attrs = {}
470
435
  self._is_ms_kernel = False
471
436
  self.out_shape = out_shape
437
+ self.out_dtype = out_dtype
472
438
 
473
439
  self._check_platform()
474
440
  self._check_func()
@@ -486,13 +452,17 @@ class Custom(ops.PrimitiveWithInfer):
486
452
 
487
453
  if self.out_shape is None and self.func_type == "aot":
488
454
  self.add_prim_attr("cpp_infer_shape", True)
489
- self.out_dtype = out_dtype
455
+ if self.out_dtype is None and self.func_type == "aot":
456
+ self.add_prim_attr("cpp_infer_type", True)
457
+ self.multi_output = (reg_info is not None and (len(reg_info.get("outputs", [])) > 1))
458
+ self.add_prim_attr("multi_output", self.multi_output)
459
+
490
460
  self.bprop = bprop
491
461
  self.fake_output = False
492
462
  self.single_scalar_output = False
493
- if not self.out_dtype:
463
+ if not self.out_dtype and not self.func_type == "pyfunc":
494
464
  self.fake_output = True
495
- elif not self.out_shape:
465
+ elif not self.out_shape and self.func_type == "pyfunc":
496
466
  self.single_scalar_output = True
497
467
  self.add_prim_attr("fake_output", self.fake_output)
498
468
  self.add_prim_attr("single_scalar_output", self.single_scalar_output)
@@ -508,10 +478,10 @@ class Custom(ops.PrimitiveWithInfer):
508
478
 
509
479
  self.add_prim_attr("func_type", self.func_type)
510
480
  self._update_attr()
511
- self.enable_pyboost = False
512
- self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
513
- if context.get_context("device_target") == "Ascend" and self.func_type == "aot":
514
- self.enable_pyboost = True
481
+
482
+ self.enable_pyboost = (context.get_context("device_target") == "Ascend" and self.func_type == "aot")
483
+ if self.enable_pyboost:
484
+ self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
515
485
  for key, value in super().get_attr_dict().items():
516
486
  self.custom_pyboost.add_prim_attr(key, value)
517
487
 
@@ -554,10 +524,15 @@ class Custom(ops.PrimitiveWithInfer):
554
524
  infer_dtype = mstype.int32
555
525
  if self.func_type == "aot":
556
526
  if infer_shape is None:
557
- logger.warning("{}, 'out_shape' is None. Add a placeholder instead. "
558
- "A CPP version of infer shape function is required "
559
- "in this case.".format(self.log_prefix))
527
+ logger.debug("{}, 'out_shape' is None. Add a placeholder instead. "
528
+ "A CPP version of infer shape function is required "
529
+ "in this case.".format(self.log_prefix))
560
530
  infer_shape = (1,)
531
+ if infer_dtype is None:
532
+ logger.debug("{}, 'out_dtype' is None. Add a placeholder instead. "
533
+ "A CPP version of infer type function is required "
534
+ "in this case.".format(self.log_prefix))
535
+ infer_dtype = ms.float16
561
536
  # after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
562
537
  if not isinstance(infer_shape, (tuple, list)):
563
538
  raise TypeError("{}, 'out_shape' must be one of [tuple, list, function], but got {}"
@@ -1123,9 +1098,188 @@ class Custom(ops.PrimitiveWithInfer):
1123
1098
 
1124
1099
  def __call__(self, *args):
1125
1100
  if self.enable_pyboost:
1126
- return _convert_stub(pyboost_custom_ext(self.custom_pyboost, [args]))
1101
+ res = pyboost_custom_ext(self.custom_pyboost, [args])
1102
+ return res if self.multi_output else res[0]
1127
1103
  should_elim, output = self.check_elim(*args)
1128
1104
  if should_elim:
1129
1105
  return output
1130
1106
  # pylint: disable=protected-access
1131
1107
  return ops.primitive._run_op(self, self.name, args)
1108
+
1109
+
1110
+ class CustomOpBuilder:
1111
+ r"""
1112
+ CustomOpBuilder is used to initialize and configure custom operators for MindSpore.
1113
+ Users can define and load custom operator modules through this class and apply them to the network.
1114
+
1115
+ In most cases, users only need to provide the source files and additional compilation options in the constructor
1116
+ and call the `load` method to complete the compilation and loading of the operator.
1117
+ If users have specific customization requirements, they can inherit this class and override certain methods.
1118
+ It is important to note that if methods are overridden, some parameters passed to the constructor may be ignored.
1119
+
1120
+ .. warning::
1121
+ This is an experimental API that is subject to change.
1122
+
1123
+ Args:
1124
+ name (str): The unique name of the custom operator module, used to identify the operator.
1125
+ sources (Union[str, list[str]]): The source file(s) of the custom operator. It can be a single file path or
1126
+ a list of file paths.
1127
+ backend (str, optional): The target backend for the operator, such as "CPU" or "Ascend". Default: ``None``.
1128
+ include_paths (list[str], optional): Additionally included paths needed during compilation. Default: ``None``.
1129
+ cflags (str, optional): Extra C++ compiler flags to be used during compilation. Default: ``None``.
1130
+ ldflags (str, optional): Extra linker flags to be used during linking. Default: ``None``.
1131
+ kwargs (dict, optional): Additional keyword arguments for future extensions or specific custom requirements.
1132
+
1133
+ .. note::
1134
+ - If the `backend` argument is provided, additional default flags will be automatically added to
1135
+ the compilation and linking steps to support the operator's target backend. The default options
1136
+ can be referenced in the implementation of the `get_cflags` and `get_ldflags` methods in the `CustomOpBuilder
1137
+ <https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/ops/operations/custom_ops.py>`_.
1138
+ - The `sources` argument must point to valid source files for the custom operator.
1139
+
1140
+ Supported Platforms:
1141
+ ``Ascend`` ``CPU``
1142
+
1143
+ Examples:
1144
+ >>> from mindspore import ops
1145
+ >>> builder = ops.CustomOpBuilder(
1146
+ ... name="custom_op_cpu",
1147
+ ... sources="custom_ops_impl/pybind_op_cpu.cpp",
1148
+ ... backend="CPU"
1149
+ ... )
1150
+ >>> my_ops = builder.load()
1151
+ """
1152
+ _mindspore_path = None
1153
+ _loaded_ops = {}
1154
+ _ms_code_base = None
1155
+
1156
+ def __init__(self, name, sources, backend=None, include_paths=None, cflags=None, ldflags=None, **kwargs):
1157
+ self.name = name
1158
+ self.source = sources
1159
+ self.backend = backend
1160
+ self.include_paths = include_paths
1161
+ self.cflags = cflags
1162
+ self.ldflags = ldflags
1163
+ if CustomOpBuilder._mindspore_path is None:
1164
+ CustomOpBuilder._mindspore_path = os.path.dirname(os.path.abspath(ms.__file__))
1165
+ CustomOpBuilder._ms_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include")
1166
+ if self.backend == "Ascend":
1167
+ self.ascend_cann_path = os.getenv("ASCEND_OPP_PATH").split('opp')[0]
1168
+
1169
+ def get_sources(self):
1170
+ """
1171
+ Get the source files for the custom operator.
1172
+
1173
+ Returns:
1174
+ str or list[str], The source file(s) for the operator.
1175
+ """
1176
+ return self.source
1177
+
1178
+ def get_include_paths(self):
1179
+ """
1180
+ Get the include paths required for compiling the custom operator.
1181
+
1182
+ Returns:
1183
+ list[str], A list of include paths.
1184
+ """
1185
+ include_list = self.include_paths if self.include_paths is not None else []
1186
+ include_list.append(CustomOpBuilder._mindspore_path)
1187
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include"))
1188
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party"))
1189
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party/robin_hood_hashing"))
1190
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party/securec/include"))
1191
+
1192
+ if self.backend == "Ascend":
1193
+ include_list.append(os.path.join(self.ascend_cann_path, "include"))
1194
+ include_list += self._get_ms_inner_includes()
1195
+ return include_list
1196
+
1197
+ def _get_ms_inner_includes(self):
1198
+ """include paths for inner module interface."""
1199
+ ms_inner_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include", "mindspore")
1200
+ include_list = []
1201
+ include_list.append(ms_inner_code_base + "/core/include")
1202
+ include_list.append(ms_inner_code_base + "/core/mindrt/include")
1203
+ include_list.append(ms_inner_code_base + "/core/mindrt")
1204
+ include_list.append(ms_inner_code_base + "/ops")
1205
+ include_list.append(ms_inner_code_base + "/ops/kernel/include")
1206
+ include_list.append(ms_inner_code_base + "/ccsrc")
1207
+ include_list.append(ms_inner_code_base + "/ccsrc/include")
1208
+ include_list.append(ms_inner_code_base + "/ccsrc/minddata/mindrecord/include")
1209
+ return include_list
1210
+
1211
+ def get_cflags(self):
1212
+ """
1213
+ Get the C++ compiler flags for building the custom operator.
1214
+
1215
+ Returns:
1216
+ list[str], A list of C++ compiler flags.
1217
+ """
1218
+ flags = ['-fstack-protector-all', '-fPIC', '-pie']
1219
+ flags += ['-DENABLE_FAST_HASH_TABLE=1']
1220
+ if self.backend == "Ascend":
1221
+ flags.append('-DCUSTOM_ASCEND_OP')
1222
+ if self.cflags is not None:
1223
+ flags.append(self.cflags)
1224
+ return flags
1225
+
1226
+ def get_ldflags(self):
1227
+ """
1228
+ Get the linker flags for building the custom operator.
1229
+
1230
+ Returns:
1231
+ list[str], A list of linker flags.
1232
+ """
1233
+ flags = ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath', '-s']
1234
+ flags += [
1235
+ '-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib')),
1236
+ '-lmindspore_core',
1237
+ '-lmindspore_ms_backend',
1238
+ '-lmindspore_pynative'
1239
+ ]
1240
+ if self.backend == "Ascend":
1241
+ flags.append('-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib/plugin')))
1242
+ flags.append('-L' + os.path.abspath(os.path.join(self.ascend_cann_path, "lib64")))
1243
+ flags.append('-lascendcl')
1244
+ flags.append('-l:libmindspore_ascend.so.2')
1245
+ if self.ldflags is not None:
1246
+ flags.append(self.ldflags)
1247
+ return flags
1248
+
1249
+ def build(self):
1250
+ """
1251
+ Build the custom operator module.
1252
+
1253
+ This method generates a dynamic library file for the custom operator based on the provided source files,
1254
+ include paths, compilation flags, and link flags.
1255
+
1256
+ Returns:
1257
+ str, The path to the compiled module.
1258
+ """
1259
+ return ExtensionBuilder().build(
1260
+ module_name=self.name,
1261
+ sources=self.get_sources(),
1262
+ extra_include_paths=self.get_include_paths(),
1263
+ extra_cflags=self.get_cflags(),
1264
+ extra_ldflags=self.get_ldflags())
1265
+
1266
+ def load(self):
1267
+ """
1268
+ Build and load the custom operator module.
1269
+
1270
+ Returns:
1271
+ Module, The loaded custom operator module.
1272
+ """
1273
+ if self.name in CustomOpBuilder._loaded_ops:
1274
+ return CustomOpBuilder._loaded_ops[self.name]
1275
+ module_path = self.build()
1276
+ mod = self._import_module(module_path)
1277
+ CustomOpBuilder._loaded_ops[self.name] = mod
1278
+ return mod
1279
+
1280
+ def _import_module(self, module_path):
1281
+ """Import module from library."""
1282
+ spec = importlib.util.spec_from_file_location(self.name, module_path)
1283
+ module = importlib.util.module_from_spec(spec)
1284
+ spec.loader.exec_module(module)
1285
+ return module