mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (577) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,6 @@ import numpy as np
29
29
  import mindspore as ms
30
30
  from mindspore._c_expression import Oplib, typing
31
31
  from mindspore._c_expression import pyboost_custom_ext
32
- from mindspore.common._stub_tensor import _convert_stub
33
32
  from mindspore import context
34
33
  from mindspore.common import Tensor
35
34
  from mindspore.common import dtype as mstype
@@ -40,6 +39,7 @@ from mindspore.communication.management import get_rank, GlobalComm
40
39
  from ._ms_kernel import determine_variable_usage
41
40
  from ._custom_grad import autodiff_bprop
42
41
  from ._pyfunc_registry import add_pyfunc
42
+ from ._custom_ops_utils import ExtensionBuilder
43
43
 
44
44
  if platform.system() != "Windows":
45
45
  import fcntl
@@ -183,10 +183,16 @@ class _CustomExt(ops.PrimitiveWithInfer):
183
183
 
184
184
  infer_value = None
185
185
  if infer_shape is None:
186
- logger.warning("'out_shape' is None. Add a placeholder instead. "
187
- "A CPP version of infer shape function is required "
188
- "in this case.")
186
+ logger.debug("'out_shape' is None. Add a placeholder instead. "
187
+ "A CPP version of infer shape function is required "
188
+ "in this case.")
189
189
  infer_shape = (1,)
190
+ if infer_dtype is None:
191
+ logger.debug("'out_dtype' is None. Add a placeholder instead. "
192
+ "A CPP version of infer type function is required "
193
+ "in this case.")
194
+ infer_dtype = ms.float16
195
+
190
196
  # after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
191
197
  if not isinstance(infer_shape, (tuple, list)):
192
198
  raise TypeError("'out_shape' must be one of [tuple, list, function], but got {}".format(type(infer_shape)))
@@ -215,7 +221,7 @@ class Custom(ops.PrimitiveWithInfer):
215
221
  function if needed. Then these `Custom` objects can be directly used in neural networks.
216
222
  Detailed description and introduction of user-defined operators, including correct writing of parameters,
217
223
  please refer to `Custom Operators Tutorial
218
- <https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_ .
224
+ <https://www.mindspore.cn/tutorials/en/master/custom_program/op_custom.html>`_ .
219
225
 
220
226
  .. warning::
221
227
  - This is an experimental API that is subject to change.
@@ -223,8 +229,6 @@ class Custom(ops.PrimitiveWithInfer):
223
229
  .. note::
224
230
  The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
225
231
 
226
- - "hybrid": supports ["GPU", "CPU"].
227
- - "akg": supports ["GPU", "CPU"].
228
232
  - "aot": supports ["GPU", "CPU", "Ascend"].
229
233
  - "pyfunc": supports ["CPU"].
230
234
  - "julia": supports ["CPU"].
@@ -233,11 +237,7 @@ class Custom(ops.PrimitiveWithInfer):
233
237
  func (Union[function, str]):
234
238
 
235
239
  - function: If func is of function type, then func should be a Python function which describes the
236
- computation logic of a user defined operator. The function can be one of the following:
237
-
238
- 1. A AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
239
- 2. A pure python function
240
- 3. An kernel decorated function written by the Hybrid DSL.
240
+ computation logic of a user defined operator.
241
241
 
242
242
  - str: If func is of str type, then str should be a path of file along with a function name.
243
243
  This could be used when func_type is "aot" or "julia".
@@ -251,13 +251,11 @@ class Custom(ops.PrimitiveWithInfer):
251
251
 
252
252
  - "xxx.so" file generation:
253
253
 
254
- 1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"),
255
- use nvcc command to compile
256
- it.(ex. :code:`nvcc --shared -Xcompiler -fPIC -o add.so add.cu`)
254
+ 1) GPU Platform: Given user defined "xxx.cu" file (ex. "{path}/add.cu"), use nvcc command to compile
255
+ it.(ex. "nvcc --shared -Xcompiler -fPIC -o add.so add.cu")
257
256
 
258
- 2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"),
259
- use g++/gcc command to
260
- compile it.(ex. :code:`g++ --shared -fPIC -o add.so add.cc`)
257
+ 2) CPU Platform: Given user defined "xxx.cc" file (ex. "{path}/add.cc"), use g++/gcc command to
258
+ compile it.(ex. "g++ --shared -fPIC -o add.so add.cc")
261
259
 
262
260
  - Define a "xxx.cc"/"xxx.cu" file:
263
261
 
@@ -305,14 +303,17 @@ class Custom(ops.PrimitiveWithInfer):
305
303
  b) Ascend platform.
306
304
  Before using Custom operators on the Ascend platform, users must first develop custom operators
307
305
  based on Ascend C and compile them. The complete development and usage process can refer to the
308
- tutorial `AOT-Type Custom Operators(Ascend) <https://www.mindspore.cn/docs/en/master/model_train/custom_program/operation/op_custom_ascendc.html>`_.
306
+ tutorial `AOT-Type Custom Operators(Ascend)
307
+ <https://www.mindspore.cn/tutorials/en/master/custom_program/operation/op_custom_ascendc.html>`_.
309
308
  By passing the name of the operator through the input parameter `func`, there are two usage methods
310
- based on the implementation of the infer shape function:
311
-
312
- - Python infer: If the operator's infer shape is implemented in Python, that is, the infer shape
313
- function is passed through the `out_shape` parameter, specify `func="CustomName"` .
314
- - C++ infer: If the operator's infer shape is implemented through C++, then pass the path of the
315
- infer shape implementation file in `func` and separate the operator name with `:`,
309
+ based on the implementation of the infer function:
310
+
311
+ - Python infer: If the operator's infer function is implemented in Python, that is, the infer shape
312
+ function is passed through the `out_shape` parameter, and the infer type is passed throuht the
313
+ `out_dtype`, then the `func` should be specified as the operator name, for example,
314
+ `func="CustomName"`.
315
+ - C++ infer: If the operator's infer function is implemented through C++, then pass the path of the
316
+ infer function implementation file in `func` and separate the operator name with `:`,
316
317
  for example: `func="add_custom_infer.cc:AddCustom"` .
317
318
 
318
319
  2. for "julia":
@@ -356,7 +357,7 @@ class Custom(ops.PrimitiveWithInfer):
356
357
 
357
358
  func_type (str): The implementation type of `func`, should be one of
358
359
 
359
- [ ``"hybrid"`` , ``"akg"`` , ``"aot"`` , ``"pyfunc"`` , ``"julia"`` ].
360
+ [ ``"aot"`` , ``"pyfunc"`` , ``"julia"`` ].
360
361
 
361
362
  bprop (function): The back propagation function of `func`. Default: ``None`` .
362
363
  reg_info (Union[str, dict, list, tuple]): Represents the registration information(reg info) of `func` with
@@ -400,50 +401,14 @@ class Custom(ops.PrimitiveWithInfer):
400
401
  >>> input_x = Tensor(np.ones([16, 16]).astype(np.float32))
401
402
  >>> input_y = Tensor(np.ones([16, 16]).astype(np.float32))
402
403
  >>>
403
- >>> # Example, func_type = "hybrid"
404
- >>> # This is the default func_type in Custom,
405
- >>> # and both out_shape and out_dtype can be None(default value).
406
- >>> # In this case, the input func must be a function written in the Hybrid DSL
407
- >>> # and decorated by @kernel.
408
- >>> @kernel
409
- ... def add_script(a, b):
410
- ... c = output_tensor(a.shape, a.dtype)
411
- ... for i0 in range(a.shape[0]):
412
- ... for i1 in range(a.shape[1]):
413
- ... c[i0, i1] = a[i0, i1] + b[i0, i1]
414
- ... return c
415
- >>>
416
- >>> test_op_hybrid = ops.Custom(add_script)
417
- >>> output = test_op_hybrid(input_x, input_y)
418
- >>> # the result will be a 16 * 16 tensor with all elements 2
419
- >>> print(output.shape)
420
- (16, 16)
421
- >>> # Example, func_type = "aot"
422
- >>> def test_aot(x, y, out_shapes, out_types):
423
- ... program = ops.Custom("./reorganize.so:CustomReorganize", out_shapes, out_types, "aot")
424
- ... out = program(x, y)
425
- ... return out
426
- >>>
427
404
  >>> # Example, func_type = "pyfunc"
428
- >>> def func_multi_output(x1, x2):
429
- ... return (x1 + x2), (x1 - x2)
405
+ >>> def func_pyfunc(x1, x2):
406
+ ... return x1 + x2
430
407
  >>>
431
- >>> test_pyfunc = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")
408
+ >>> test_pyfunc = ops.Custom(func_pyfunc, lambda x, _: x, lambda x, _: x, "pyfunc")
432
409
  >>> output = test_pyfunc(input_x, input_y)
433
- >>>
434
- >>> # Example, func_type = "julia"
435
- >>> # julia code:
436
- >>> # add.jl
437
- >>> # module Add
438
- >>> # function add(x, y, z)
439
- >>> # z .= x + y
440
- >>> # return z
441
- >>> # end
442
- >>> # end
443
- >>> def test_julia(x, y, out_shapes, out_types):
444
- ... program = ops.Custom("./add.jl:Add:add", out_shapes, out_types, "julia")
445
- ... out = program(x, y)
446
- ... return out
410
+ >>> print(output.shape)
411
+ (16, 16)
447
412
  """
448
413
 
449
414
  registered_func = {}
@@ -469,6 +434,7 @@ class Custom(ops.PrimitiveWithInfer):
469
434
  self._func_compile_attrs = {}
470
435
  self._is_ms_kernel = False
471
436
  self.out_shape = out_shape
437
+ self.out_dtype = out_dtype
472
438
 
473
439
  self._check_platform()
474
440
  self._check_func()
@@ -486,13 +452,17 @@ class Custom(ops.PrimitiveWithInfer):
486
452
 
487
453
  if self.out_shape is None and self.func_type == "aot":
488
454
  self.add_prim_attr("cpp_infer_shape", True)
489
- self.out_dtype = out_dtype
455
+ if self.out_dtype is None and self.func_type == "aot":
456
+ self.add_prim_attr("cpp_infer_type", True)
457
+ self.multi_output = (reg_info is not None and (len(reg_info.get("outputs", [])) > 1))
458
+ self.add_prim_attr("multi_output", self.multi_output)
459
+
490
460
  self.bprop = bprop
491
461
  self.fake_output = False
492
462
  self.single_scalar_output = False
493
- if not self.out_dtype:
463
+ if not self.out_dtype and not self.func_type == "pyfunc":
494
464
  self.fake_output = True
495
- elif not self.out_shape:
465
+ elif not self.out_shape and self.func_type == "pyfunc":
496
466
  self.single_scalar_output = True
497
467
  self.add_prim_attr("fake_output", self.fake_output)
498
468
  self.add_prim_attr("single_scalar_output", self.single_scalar_output)
@@ -508,10 +478,10 @@ class Custom(ops.PrimitiveWithInfer):
508
478
 
509
479
  self.add_prim_attr("func_type", self.func_type)
510
480
  self._update_attr()
511
- self.enable_pyboost = False
512
- self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
513
- if context.get_context("device_target") == "Ascend" and self.func_type == "aot":
514
- self.enable_pyboost = True
481
+
482
+ self.enable_pyboost = (context.get_context("device_target") == "Ascend" and self.func_type == "aot")
483
+ if self.enable_pyboost:
484
+ self.custom_pyboost = _CustomExt(self.func, self.out_shape, self.out_dtype, self.bprop)
515
485
  for key, value in super().get_attr_dict().items():
516
486
  self.custom_pyboost.add_prim_attr(key, value)
517
487
 
@@ -554,10 +524,15 @@ class Custom(ops.PrimitiveWithInfer):
554
524
  infer_dtype = mstype.int32
555
525
  if self.func_type == "aot":
556
526
  if infer_shape is None:
557
- logger.warning("{}, 'out_shape' is None. Add a placeholder instead. "
558
- "A CPP version of infer shape function is required "
559
- "in this case.".format(self.log_prefix))
527
+ logger.debug("{}, 'out_shape' is None. Add a placeholder instead. "
528
+ "A CPP version of infer shape function is required "
529
+ "in this case.".format(self.log_prefix))
560
530
  infer_shape = (1,)
531
+ if infer_dtype is None:
532
+ logger.debug("{}, 'out_dtype' is None. Add a placeholder instead. "
533
+ "A CPP version of infer type function is required "
534
+ "in this case.".format(self.log_prefix))
535
+ infer_dtype = ms.float16
561
536
  # after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
562
537
  if not isinstance(infer_shape, (tuple, list)):
563
538
  raise TypeError("{}, 'out_shape' must be one of [tuple, list, function], but got {}"
@@ -1123,9 +1098,188 @@ class Custom(ops.PrimitiveWithInfer):
1123
1098
 
1124
1099
  def __call__(self, *args):
1125
1100
  if self.enable_pyboost:
1126
- return _convert_stub(pyboost_custom_ext(self.custom_pyboost, [args]))
1101
+ res = pyboost_custom_ext(self.custom_pyboost, [args])
1102
+ return res if self.multi_output else res[0]
1127
1103
  should_elim, output = self.check_elim(*args)
1128
1104
  if should_elim:
1129
1105
  return output
1130
1106
  # pylint: disable=protected-access
1131
1107
  return ops.primitive._run_op(self, self.name, args)
1108
+
1109
+
1110
+ class CustomOpBuilder:
1111
+ r"""
1112
+ CustomOpBuilder is used to initialize and configure custom operators for MindSpore.
1113
+ Users can define and load custom operator modules through this class and apply them to the network.
1114
+
1115
+ In most cases, users only need to provide the source files and additional compilation options in the constructor
1116
+ and call the `load` method to complete the compilation and loading of the operator.
1117
+ If users have specific customization requirements, they can inherit this class and override certain methods.
1118
+ It is important to note that if methods are overridden, some parameters passed to the constructor may be ignored.
1119
+
1120
+ .. warning::
1121
+ This is an experimental API that is subject to change.
1122
+
1123
+ Args:
1124
+ name (str): The unique name of the custom operator module, used to identify the operator.
1125
+ sources (Union[str, list[str]]): The source file(s) of the custom operator. It can be a single file path or
1126
+ a list of file paths.
1127
+ backend (str, optional): The target backend for the operator, such as "CPU" or "Ascend". Default: ``None``.
1128
+ include_paths (list[str], optional): Additionally included paths needed during compilation. Default: ``None``.
1129
+ cflags (str, optional): Extra C++ compiler flags to be used during compilation. Default: ``None``.
1130
+ ldflags (str, optional): Extra linker flags to be used during linking. Default: ``None``.
1131
+ kwargs (dict, optional): Additional keyword arguments for future extensions or specific custom requirements.
1132
+
1133
+ .. note::
1134
+ - If the `backend` argument is provided, additional default flags will be automatically added to
1135
+ the compilation and linking steps to support the operator's target backend. The default options
1136
+ can be referenced in the implementation of the `get_cflags` and `get_ldflags` methods in the `CustomOpBuilder
1137
+ <https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/ops/operations/custom_ops.py>`_.
1138
+ - The `sources` argument must point to valid source files for the custom operator.
1139
+
1140
+ Supported Platforms:
1141
+ ``Ascend`` ``CPU``
1142
+
1143
+ Examples:
1144
+ >>> from mindspore import ops
1145
+ >>> builder = ops.CustomOpBuilder(
1146
+ ... name="custom_op_cpu",
1147
+ ... sources="custom_ops_impl/pybind_op_cpu.cpp",
1148
+ ... backend="CPU"
1149
+ ... )
1150
+ >>> my_ops = builder.load()
1151
+ """
1152
+ _mindspore_path = None
1153
+ _loaded_ops = {}
1154
+ _ms_code_base = None
1155
+
1156
+ def __init__(self, name, sources, backend=None, include_paths=None, cflags=None, ldflags=None, **kwargs):
1157
+ self.name = name
1158
+ self.source = sources
1159
+ self.backend = backend
1160
+ self.include_paths = include_paths
1161
+ self.cflags = cflags
1162
+ self.ldflags = ldflags
1163
+ if CustomOpBuilder._mindspore_path is None:
1164
+ CustomOpBuilder._mindspore_path = os.path.dirname(os.path.abspath(ms.__file__))
1165
+ CustomOpBuilder._ms_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include")
1166
+ if self.backend == "Ascend":
1167
+ self.ascend_cann_path = os.getenv("ASCEND_OPP_PATH").split('opp')[0]
1168
+
1169
+ def get_sources(self):
1170
+ """
1171
+ Get the source files for the custom operator.
1172
+
1173
+ Returns:
1174
+ str or list[str], The source file(s) for the operator.
1175
+ """
1176
+ return self.source
1177
+
1178
+ def get_include_paths(self):
1179
+ """
1180
+ Get the include paths required for compiling the custom operator.
1181
+
1182
+ Returns:
1183
+ list[str], A list of include paths.
1184
+ """
1185
+ include_list = self.include_paths if self.include_paths is not None else []
1186
+ include_list.append(CustomOpBuilder._mindspore_path)
1187
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include"))
1188
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party"))
1189
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party/robin_hood_hashing"))
1190
+ include_list.append(os.path.join(CustomOpBuilder._mindspore_path, "include/third_party/securec/include"))
1191
+
1192
+ if self.backend == "Ascend":
1193
+ include_list.append(os.path.join(self.ascend_cann_path, "include"))
1194
+ include_list += self._get_ms_inner_includes()
1195
+ return include_list
1196
+
1197
+ def _get_ms_inner_includes(self):
1198
+ """include paths for inner module interface."""
1199
+ ms_inner_code_base = os.path.join(CustomOpBuilder._mindspore_path, "include", "mindspore")
1200
+ include_list = []
1201
+ include_list.append(ms_inner_code_base + "/core/include")
1202
+ include_list.append(ms_inner_code_base + "/core/mindrt/include")
1203
+ include_list.append(ms_inner_code_base + "/core/mindrt")
1204
+ include_list.append(ms_inner_code_base + "/ops")
1205
+ include_list.append(ms_inner_code_base + "/ops/kernel/include")
1206
+ include_list.append(ms_inner_code_base + "/ccsrc")
1207
+ include_list.append(ms_inner_code_base + "/ccsrc/include")
1208
+ include_list.append(ms_inner_code_base + "/ccsrc/minddata/mindrecord/include")
1209
+ return include_list
1210
+
1211
+ def get_cflags(self):
1212
+ """
1213
+ Get the C++ compiler flags for building the custom operator.
1214
+
1215
+ Returns:
1216
+ list[str], A list of C++ compiler flags.
1217
+ """
1218
+ flags = ['-fstack-protector-all', '-fPIC', '-pie']
1219
+ flags += ['-DENABLE_FAST_HASH_TABLE=1']
1220
+ if self.backend == "Ascend":
1221
+ flags.append('-DCUSTOM_ASCEND_OP')
1222
+ if self.cflags is not None:
1223
+ flags.append(self.cflags)
1224
+ return flags
1225
+
1226
+ def get_ldflags(self):
1227
+ """
1228
+ Get the linker flags for building the custom operator.
1229
+
1230
+ Returns:
1231
+ list[str], A list of linker flags.
1232
+ """
1233
+ flags = ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath', '-s']
1234
+ flags += [
1235
+ '-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib')),
1236
+ '-lmindspore_core',
1237
+ '-lmindspore_ms_backend',
1238
+ '-lmindspore_pynative'
1239
+ ]
1240
+ if self.backend == "Ascend":
1241
+ flags.append('-L' + os.path.abspath(os.path.join(CustomOpBuilder._mindspore_path, 'lib/plugin')))
1242
+ flags.append('-L' + os.path.abspath(os.path.join(self.ascend_cann_path, "lib64")))
1243
+ flags.append('-lascendcl')
1244
+ flags.append('-l:libmindspore_ascend.so.2')
1245
+ if self.ldflags is not None:
1246
+ flags.append(self.ldflags)
1247
+ return flags
1248
+
1249
+ def build(self):
1250
+ """
1251
+ Build the custom operator module.
1252
+
1253
+ This method generates a dynamic library file for the custom operator based on the provided source files,
1254
+ include paths, compilation flags, and link flags.
1255
+
1256
+ Returns:
1257
+ str, The path to the compiled module.
1258
+ """
1259
+ return ExtensionBuilder().build(
1260
+ module_name=self.name,
1261
+ sources=self.get_sources(),
1262
+ extra_include_paths=self.get_include_paths(),
1263
+ extra_cflags=self.get_cflags(),
1264
+ extra_ldflags=self.get_ldflags())
1265
+
1266
+ def load(self):
1267
+ """
1268
+ Build and load the custom operator module.
1269
+
1270
+ Returns:
1271
+ Module, The loaded custom operator module.
1272
+ """
1273
+ if self.name in CustomOpBuilder._loaded_ops:
1274
+ return CustomOpBuilder._loaded_ops[self.name]
1275
+ module_path = self.build()
1276
+ mod = self._import_module(module_path)
1277
+ CustomOpBuilder._loaded_ops[self.name] = mod
1278
+ return mod
1279
+
1280
+ def _import_module(self, module_path):
1281
+ """Import module from library."""
1282
+ spec = importlib.util.spec_from_file_location(self.name, module_path)
1283
+ module = importlib.util.module_from_spec(spec)
1284
+ spec.loader.exec_module(module)
1285
+ return module