mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
1
+ # Copyright 2023-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,1087 +13,40 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """
16
- Generate operator definition from ops.yaml
16
+ Auto generate ops files.
17
17
  """
18
- import os
19
- import re
20
- import shutil
21
- import pathlib
22
18
  import logging
23
- import gen_utils
24
- from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
25
- merge_files_append, safe_load_yaml, convert_dtype_str, write_file)
26
- from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
27
- import template
28
- from template import CppTemplate
29
- from gen_pyboost_func import gen_pyboost_code
30
- from gen_aclnn_implement import gen_aclnn_kernel
31
- import gen_constants as K
32
19
 
20
+ from resources.resource_manager import prepare_resources
21
+ from common import gen_utils
33
22
 
34
- def _get_op_name(yaml_key, yaml_value):
35
- """
36
- Get op name for python class Primitive or c++ OpDef name.
37
- """
38
- # If has class item, use the specified item.
39
- class_def = yaml_value.get("class")
40
- if class_def is not None:
41
- class_name_specify = class_def.get("name")
42
- if class_name_specify is not None:
43
- return class_name_specify
44
- # Else use the default rule generate class name.
45
- op_name = yaml_key
46
- class_name_normal = ''.join(word.capitalize() for word in op_name.split('_'))
47
- return class_name_normal
23
+ from op_def.gen_op_def import generate_ops_def_files
24
+ from op_def_py.gen_op_def_py import generate_ops_py_files
25
+ from api.gen_api import generate_api_files
26
+ from aclnn.aclnn_kernel_register_auto_cc_generator import generate_aclnn_reg_file
27
+ from pyboost.gen_pyboost_func import gen_pyboost_code
48
28
 
49
29
 
50
- def _get_op_func_name(yaml_key, yaml_value):
51
- func_def = yaml_value.get('function')
52
- func_name = yaml_key
53
-
54
- if func_def is not None:
55
- item = func_def.get("name")
56
- if item is not None:
57
- func_name = item
58
- return func_name
59
-
60
-
61
- def _auto_generate_class_disabled(yaml_value):
62
- """Check whether class can be auto generated."""
63
- if 'class' not in yaml_value.keys():
64
- return False
65
- class_def = yaml_value.get("class")
66
- if 'disable' not in class_def.keys():
67
- return False
68
- disable_item = class_def.get("disable")
69
- if disable_item is True:
70
- return True
71
- if disable_item is False:
72
- return False
73
- raise TypeError(f"The disable label for class should be True or False, but get {disable_item}.")
74
-
75
-
76
- def _auto_generate_func_disabled(yaml_value):
77
- """Check whether function can be auto generated."""
78
- if 'function' not in yaml_value.keys():
79
- return False
80
- func_def = yaml_value.get('function')
81
- if 'disable' not in func_def.keys():
82
- return False
83
- disable_item = func_def.get("disable")
84
- if disable_item is True:
85
- return True
86
- if disable_item is False:
87
- return False
88
- raise TypeError(f"The disable label for function should be True or False, but get {disable_item}.")
89
-
90
-
91
- def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
92
- """
93
- Generate signature rw code
94
- """
95
- for rw_arg_name in write_list:
96
- if rw_arg_name == arg_name:
97
- return ', sig.sig_rw.RW_WRITE'
98
- for read_arg_name in read_list:
99
- if read_arg_name == arg_name:
100
- return ', sig.sig_rw.RW_READ'
101
- for ref_arg_name in ref_list:
102
- if ref_arg_name == arg_name:
103
- return ', sig.sig_rw.RW_REF'
104
- return ''
105
-
106
-
107
- def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
108
- """
109
- Generate cc signature rw code
110
- """
111
- rw_label = 'kRWDefault'
112
- for op in write_list:
113
- if op == rw_op_name:
114
- rw_label = 'kRWWrite'
115
- for op in read_list:
116
- if op == rw_op_name:
117
- rw_label = 'kRWRead'
118
- for op in ref_list:
119
- if op == rw_op_name:
120
- rw_label = 'kRWRef'
121
- return 'SignatureEnumRW::' + rw_label
122
-
123
-
124
- def signature_get_enum_dtype_cc(index):
125
- """
126
- Generate cc enum dtype code
127
- """
128
- enum_type = 'SignatureEnumDType::'
129
- type_map = {0: 'kDType',
130
- 1: 'kDType1',
131
- 2: 'kDType2',
132
- 3: 'kDType3',
133
- 4: 'kDType4',
134
- 5: 'kDType5',
135
- 6: 'kDType6',
136
- 7: 'kDType7',
137
- 8: 'kDType8',
138
- 9: 'kDType9'}
139
- if index in type_map:
140
- return enum_type + type_map[index]
141
- return enum_type + 'kDTypeEmptyDefaultValue'
142
-
143
-
144
- def signature_get_dtype_label(index):
145
- """
146
- Generate signature dtype code
147
- """
148
- dtype_index = ''
149
- if index > 0:
150
- dtype_index = f"""{index}"""
151
- return f"""dtype=sig.sig_dtype.T{dtype_index}"""
152
-
153
-
154
- def get_same_dtype_groups(args_signature, args_name):
155
- """
156
- Get same dtype groups
157
- """
158
- same_dtype_groups = {}
159
- dtype_conut = 0
160
- if args_signature is None:
161
- return same_dtype_groups, dtype_conut
162
-
163
- dtype_group = args_signature.get('dtype_group')
164
- if dtype_group is not None:
165
- args_list = []
166
- match = re.findall(r'\((.*?)\)', dtype_group)
167
- for item in match:
168
- args_list.append(item.replace(' ', '').split(","))
169
- for arg_name in args_name:
170
- if arg_name in same_dtype_groups:
171
- continue
172
- is_match = False
173
- for group in args_list:
174
- if arg_name in group:
175
- is_match = True
176
- for item in group:
177
- same_dtype_groups[item] = dtype_conut
178
- break
179
- if not is_match:
180
- same_dtype_groups[arg_name] = dtype_conut
181
- dtype_conut = dtype_conut + 1
182
- return same_dtype_groups, dtype_conut
183
-
184
-
185
- def generate_py_op_signature(op_name, args_signature, args_name, args_default):
186
- """
187
- Generate __mindspore_signature__
188
- """
189
-
190
- def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
191
- for sig_arg_name in sig_arg_names:
192
- if sig_arg_name not in args_names:
193
- raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
194
-
195
- if args_signature is None and not args_default:
196
- return ''
197
-
198
- signature_code = f""" __mindspore_signature__ = """
199
-
200
- # Init rw.
201
- write_list = []
202
- read_list = []
203
- ref_list = []
204
- if args_signature is not None:
205
- rw_write = args_signature.get('rw_write')
206
- rw_read = args_signature.get('rw_read')
207
- rw_ref = args_signature.get('rw_ref')
208
- if rw_write is not None:
209
- write_list = rw_write.replace(' ', '').split(",")
210
- _check_signature_arg_valid(op_name, write_list, args_name)
211
- if rw_read is not None:
212
- read_list = rw_read.replace(' ', '').split(",")
213
- _check_signature_arg_valid(op_name, read_list, args_name)
214
- if rw_ref is not None:
215
- ref_list = rw_ref.replace(' ', '').split(",")
216
- _check_signature_arg_valid(op_name, ref_list, args_name)
217
- # Init dtype group.
218
- same_dtype_groups, dtype_conut = get_same_dtype_groups(args_signature, args_name)
219
- _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
220
- # Only one dtype_group is set.
221
- if dtype_conut == 1 and not any([write_list, read_list, ref_list, args_default]):
222
- signature_code += '('
223
- for _ in range(len(args_name) - 1):
224
- signature_code += 'sig.sig_dtype.T, '
225
- signature_code += 'sig.sig_dtype.T)\n\n'
226
- return signature_code
227
-
228
- # Set sig.make_sig.
229
- signature_code += f""" (\n"""
230
- for arg_name in args_name:
231
- signature_code += f""" sig.make_sig('{arg_name}'"""
232
- signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
233
- if arg_name in same_dtype_groups:
234
- signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
235
- if arg_name in args_default:
236
- signature_code += f""", default=""" + str(args_default[arg_name])
237
- signature_code += f"""),\n"""
238
- signature_code += f""" )\n\n"""
239
- return signature_code
240
-
241
-
242
- def generate_cc_op_signature(args_signature, args_name):
243
- """
244
- generate signatures on in cc file
245
- :param args_signature:
246
- :param args_name:
247
- :return:
248
- """
249
- if args_signature is None:
250
- return ''
251
- signature_code = ''
252
- # Init rw.
253
- write_list = []
254
- read_list = []
255
- ref_list = []
256
- if args_signature is not None:
257
- rw_write = args_signature.get('rw_write')
258
- rw_read = args_signature.get('rw_read')
259
- rw_ref = args_signature.get('rw_ref')
260
- if rw_write is not None:
261
- write_list = rw_write.replace(' ', '').split(",")
262
- if rw_read is not None:
263
- read_list = rw_read.replace(' ', '').split(",")
264
- if rw_ref is not None:
265
- ref_list = rw_ref.replace(' ', '').split(",")
266
- # Init dtype group.
267
- same_dtype_groups, _ = get_same_dtype_groups(args_signature, args_name)
268
- for arg_name in args_name:
269
- enum_rw = signature_get_rw_label_cc(arg_name, write_list, read_list, ref_list)
270
- enum_dtype = signature_get_enum_dtype_cc(same_dtype_groups.get(arg_name))
271
- signature = f"""Signature("{arg_name}", {enum_rw}, \
272
- SignatureEnumKind::kKindPositionalKeyword, nullptr, {enum_dtype}),\n """
273
- signature_code += signature
274
- return signature_code
275
-
276
-
277
- def generate_py_op_deprecated(deprecated):
278
- """
279
- Generate @deprecated
280
- """
281
- if deprecated is None:
282
- return ''
283
- version = deprecated.get("version")
284
- if version is None:
285
- raise ValueError("The version of deprecated can't be None.")
286
- substitute = deprecated.get("substitute")
287
- if substitute is None:
288
- raise ValueError("The substitute of deprecated can't be None.")
289
- use_substitute = deprecated.get("use_substitute")
290
- if use_substitute is None:
291
- raise ValueError("The use_substitute of deprecated can't be None.")
292
- if use_substitute is not True and use_substitute is not False:
293
- raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
294
-
295
- deprecated = f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
296
- return deprecated
297
-
298
-
299
- def _normalize_func_description_fromat(description):
300
- """
301
- Process description.
302
- """
303
- if not description:
304
- return description
305
- lines = description.split("\n")
306
- if len(lines) == 1:
307
- return description
308
- # Add line indentation to other lines after the first line
309
- for i in range(1, len(lines)):
310
- indent = " " if lines[i] else ""
311
- lines[i] = indent + lines[i]
312
- # Remove trailing blank lines
313
- lines = lines if lines[-1] != "" else lines[:-1]
314
- description = "\n".join(lines)
315
- return description
316
-
317
-
318
- def _get_op_description(operator_name, doc_str):
319
- """
320
- Generate ops api description.
321
- """
322
- if doc_str is None:
323
- print(f"Description is None, op_name: {operator_name}")
324
- return ""
325
- description = doc_str.get(operator_name)
326
- if description is None:
327
- print(f"Description is None, op_name: {operator_name}")
328
- return ""
329
- description = description.get("description")
330
- if description is None:
331
- print(f"Description is None, op_name: {operator_name}")
332
- return ""
333
- return _normalize_func_description_fromat(description)
334
-
335
-
336
- def generate_py_op_func(yaml_data, doc_data):
337
- """
338
- Generate operator python function api.
339
- """
340
- gen_py = ''
341
-
342
- for operator_name, operator_data in yaml_data.items():
343
- if _auto_generate_func_disabled(operator_data):
344
- continue
345
- func_name = _get_op_func_name(operator_name, operator_data)
346
- args = operator_data.get('args')
347
- class_name = _get_op_name(operator_name, operator_data)
348
- func_args = []
349
- prim_init_args = []
350
- prim_call_args = []
351
- for arg_name, arg_info in args.items():
352
- is_prim_init = arg_info.get('prim_init')
353
- has_default = 'default' in arg_info.keys()
354
-
355
- # step1: Process function args.
356
- if not has_default:
357
- func_args.append(f"""{arg_name}""")
358
- else:
359
- default_value = arg_info.get('default')
360
- func_args.append(f"""{arg_name}={default_value}""")
361
-
362
- # step2: Process primitive object init args.
363
- if is_prim_init:
364
- prim_init_args.append(arg_name)
365
-
366
- # step3: Process primitive object call args.
367
- else:
368
- prim_call_args.append(arg_name)
369
- description = _get_op_description(operator_name, doc_data)
370
- function_code = f"""\n
371
- def {func_name}({', '.join(arg for arg in func_args)}):
372
- r\"\"\"
373
- {description}
374
- \"\"\"
375
- {operator_name}_op = _get_cache_prim({class_name})({', '.join(arg_name for arg_name in prim_init_args)})
376
- return {operator_name}_op({', '.join(arg_name for arg_name in prim_call_args)})\n"""
377
-
378
- if not prim_init_args:
379
- if _auto_generate_class_disabled(operator_data):
380
- gen_py += f"""\n{operator_name}_op={class_name}()"""
381
- function_code = f"""\n
382
- def {func_name}({', '.join(arg for arg in func_args)}):
383
- r\"\"\"
384
- {description}
385
- \"\"\"
386
- return {operator_name}_op({', '.join(arg_name for arg_name in prim_call_args)})\n"""
387
- else:
388
- dis = operator_data.get("dispatch")
389
- if dis is not None:
390
- enable_pyboost = dis.get("enable")
391
- if enable_pyboost:
392
- function_code = f"""\n
393
- def {func_name}({', '.join(arg for arg in func_args)}):
394
- r\"\"\"
395
- {description}
396
- \"\"\"
397
- return {operator_name}_impl({', '.join(arg_name for arg_name, _ in args.items())})\n"""
398
- gen_py += function_code
399
-
400
- return gen_py
401
-
402
-
403
- def get_dtype(arg_info):
404
- dtype = arg_info.get('dtype')
405
- # Currently, TypeId is represented by int
406
- if dtype == 'TypeId':
407
- dtype = 'int'
408
- return dtype
409
-
410
-
411
- def process_args(class_name, args):
412
- """
413
- Process arg for yaml, get arg_name, init value, type cast, arg_handler, etc.
414
- """
415
- inputs_name = []
416
- args_name = []
417
- args_assign = []
418
- inputs_default = {}
419
- init_args_with_default = []
420
- args_handlers = {}
421
- for arg_name, arg_info in args.items():
422
- dtype = get_dtype(arg_info)
423
- default_value = arg_info.get('default')
424
- has_default = 'default' in arg_info.keys()
425
- is_prim_init = arg_info.get('prim_init')
426
- arg_handler = arg_info.get('arg_handler')
427
-
428
- # step1: get args infos:
429
- if is_prim_init:
430
- # step1.1: get args name:
431
- args_name.append(arg_name)
432
- # step1.2: get args assign with default value:
433
- if has_default:
434
- init_args_with_default.append(f"""{arg_name}={default_value}""")
435
- else:
436
- init_args_with_default.append(f"""{arg_name}""")
437
-
438
- # step1.3: get args set prim arg expression:
439
- assign_str = gen_utils.get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype)
440
- if arg_handler:
441
- assign_str = f""" self._set_prim_arg_with_handler("{arg_name}", {assign_str}, {arg_handler})"""
442
- else:
443
- assign_str = f""" self._set_prim_arg("{arg_name}", {assign_str})"""
444
- args_assign.append(assign_str)
445
- # step2: get inputs infos:
446
- else:
447
- # step2.1: get inputs name:
448
- inputs_name.append(arg_name)
449
-
450
- # step2.2: get default value of inputs:
451
- if has_default:
452
- inputs_default[arg_name] = default_value
453
-
454
- # step2.3: get args_handler functions for inputs
455
- if arg_handler:
456
- args_handlers[arg_name] = arg_handler
457
-
458
- return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
459
-
460
-
461
- def generate_pyboost_import_header(yaml_data):
462
- """
463
- Generate python primitive
464
- """
465
- pyboost_import_header = ''
466
- import_pyboost = CppTemplate("from mindspore._c_expression import $var\n")
467
- for operator_name, operator_data in yaml_data.items():
468
- is_pyboost = is_pyboost_enable(operator_data)
469
- if is_pyboost:
470
- header = import_pyboost.replace(var=get_pyboost_name(operator_name))
471
- pyboost_import_header += header
472
- return pyboost_import_header
473
-
474
-
475
- def _generate_class_description(class_name, func_name, input_args, init_args, func_disabled, doc_str):
476
- """Generate description for every primitive definition."""
477
- if func_disabled:
478
- # if function disabled, function name is equal to operator_name
479
- description = _get_op_description(func_name, doc_str)
480
- description = f""" r\"\"\"
481
- {description}
482
- \"\"\"
483
- """
484
- return description
485
-
486
- # If function is an released API, refer to the function doc.
487
- description_str = f""" r\"\"\"
488
- .. code-block::
489
-
490
- prim = ops.{class_name}({', '.join(init_args)})
491
- out = prim({', '.join(input_args)})
492
-
493
- is equivalent to
494
-
495
- .. code-block::
496
-
497
- ops.{func_name}({", ".join(input_args + init_args)})
498
-
499
- Refer to :func:`mindspore.ops.{func_name}` for more details.
500
- \"\"\"
501
- """
502
- return description_str
503
-
504
-
505
- def get_init_code(init_code, operator_data):
506
- """
507
- Generate init code for primitive
508
- """
509
- labels = operator_data.get('labels')
510
- if labels is not None:
511
- if init_code != "":
512
- init_code += "\n"
513
- init_code += \
514
- '\n'.join([f""" self.add_prim_attr("{key}", {value})""" for key, value in labels.items()])
515
- if init_code == "":
516
- init_code = f""" pass"""
517
- return init_code
518
-
519
-
520
- def generate_py_primitive(yaml_data, doc_str):
521
- """
522
- Generate python primitive
523
- """
524
-
525
- def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
526
- """Generate arg_handler"""
527
- arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
528
- if is_optional:
529
- arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
530
- return arg_handler_call
531
-
532
- gen_py = ''
533
- for operator_name, operator_data in yaml_data.items():
534
- if _auto_generate_class_disabled(operator_data):
535
- continue
536
- class_name = _get_op_name(operator_name, operator_data)
537
- func_name = _get_op_func_name(operator_name, operator_data)
538
- pyboost_func_name = get_pyboost_name(operator_name)
539
- args = operator_data.get('args')
540
- inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = \
541
- process_args(class_name, args)
542
- init_code = '\n'.join(args_assign)
543
- signature_code = generate_py_op_signature(class_name, operator_data.get('args_signature'), inputs_args,
544
- inputs_default)
545
- deprecated_code = generate_py_op_deprecated(operator_data.get('deprecated'))
546
- init_code = get_init_code(init_code, operator_data)
547
- primitive_code = f"""\n
548
- class {class_name}(Primitive):\n"""
549
- func_disabled = _auto_generate_func_disabled(operator_data)
550
- primitive_code += _generate_class_description(class_name, func_name, inputs_args, init_args, func_disabled,
551
- doc_str)
552
- if signature_code != "":
553
- primitive_code += signature_code
554
- if deprecated_code != "":
555
- primitive_code += deprecated_code
556
- primitive_code += f""" @prim_arg_register
557
- def __init__(self"""
558
- if init_args_with_default:
559
- primitive_code += ", " + f"""{', '.join(init_args_with_default) if init_args_with_default else ''}"""
560
- call_args = []
561
- for name in inputs_args:
562
- call_args.append(f"""{name}={inputs_default[name]}""" if name in inputs_default else name)
563
- primitive_code += f"""):
564
- {init_code}
565
-
566
- def __call__(self, {', '.join(call_args)}):"""
567
- is_pyboost = is_pyboost_enable(operator_data)
568
- if is_pyboost:
569
- primitive_code += f"""
570
- return _convert_stub({pyboost_func_name}(self, ["""
571
- else:
572
- primitive_code += f"""
573
- return super().__call__("""
574
- if inputs_args:
575
- args_with_handler = []
576
- for arg in inputs_args:
577
- if arg in args_handlers:
578
- is_optional = inputs_default.get(arg) == "None"
579
- args_with_handler.append(_generate_arg_handler(class_name, arg, args_handlers[arg], is_optional))
580
- else:
581
- args_with_handler.append(arg)
582
- primitive_code += ', '.join(args_with_handler)
583
-
584
- if init_args:
585
- primitive_code += ', '
586
- primitive_code += ', '.join([f'self.{arg}' for arg in init_args])
587
- if is_pyboost:
588
- primitive_code += """]))"""
589
- else:
590
- primitive_code += """)
591
- """
592
-
593
- gen_py += primitive_code
594
- if not init_args:
595
- prim_op_object = f"""\n
596
- {operator_name}_op={class_name}()
597
- """
598
- gen_py += prim_op_object
599
- return gen_py
600
-
601
-
602
- def generate_op_name_opdef(yaml_data):
603
- """
604
- Generate op name
605
- """
606
- op_name_head = f"""
607
- #ifndef MINDSPORE_CORE_OP_NAME_H_
608
- #define MINDSPORE_CORE_OP_NAME_H_
609
-
610
- namespace mindspore::ops {{
611
- """
612
-
613
- op_name_end = f"""}} // namespace mindspore::ops
614
-
615
- #endif // MINDSPORE_CORE_OP_NAME_H_
616
- """
617
-
618
- op_name_gen = ''
619
- op_name_gen += op_name_head
620
- for operator_name, operator_data in yaml_data.items():
621
- k_name_op = _get_op_name(operator_name, operator_data)
622
- op_name_gen += f"""constexpr auto kName{k_name_op} = "{k_name_op}";
623
- """
624
-
625
- op_name_gen += op_name_end
626
- return op_name_gen
627
-
628
-
629
- def generate_op_prim_opdef(yaml_data):
630
- """
631
- Generate primitive c++ definition
632
- """
633
- ops_prim_head = f"""
634
- #ifndef MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
635
- #define MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
636
-
637
- #include <memory>
638
- #include "ir/anf.h"
639
- #include "ir/primitive.h"
640
- #include "{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
641
- #include "mindapi/base/macros.h"
642
-
643
- namespace mindspore::prim {{
644
- """
645
-
646
- ops_prim_end = f"""}} // namespace mindspore::prim
647
- #endif // MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
648
- """
649
-
650
- ops_prim_gen = ''
651
- ops_prim_gen += ops_prim_head
652
- for operator_name, operator_data in yaml_data.items():
653
- k_name_op = _get_op_name(operator_name, operator_data)
654
- ops_prim_gen += f"""GVAR_DEF(PrimitivePtr, kPrim{k_name_op}, std::make_shared<Primitive>(ops::kName{k_name_op}))
655
- """
656
- ops_prim_gen += ops_prim_end
657
- return ops_prim_gen
658
-
659
-
660
- def generate_lite_ops(yaml_data):
661
- """
662
- Generate BaseOperator parameter set and get func
663
- """
664
- lite_ops_h_head = f"""
665
- #ifndef MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
666
- #define MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
667
-
668
- #include <vector>
669
- #include "ops/base_operator.h"
670
- #include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
671
-
672
- namespace mindspore::ops {{
673
- """
674
-
675
- lite_ops_h_end = f"""}} // namespace mindspore::ops
676
- #endif // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
677
- """
678
-
679
- lite_ops_cc_head = f"""
680
- #include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_lite_ops.h"
681
- #include "mindapi/helper.h"
682
- #include "ops/primitive_c.h"
683
- #include "ops/base_operator.h"
684
- #include "abstract/abstract_value.h"
685
-
686
- namespace mindspore::ops {{
687
- """
688
-
689
- lite_ops_cc_end = f"""}} // namespace mindspore::ops
690
- """
691
-
692
- lite_ops_h_gen = ''
693
- lite_ops_cc_gen = ''
694
-
695
- lite_ops_h_gen += lite_ops_h_head
696
- lite_ops_cc_gen += lite_ops_cc_head
697
- for operator_name, operator_data in yaml_data.items():
698
- op_name = _get_op_name(operator_name, operator_data)
699
- lite_ops_h_gen += f"""class OPS_API {op_name} : public BaseOperator {{
700
- public:
701
- MIND_API_BASE_MEMBER({op_name});
702
- {op_name}() : BaseOperator(kName{op_name}) {{}}\n"""
703
- args = operator_data.get('args')
704
- for _, (arg_name, arg_info) in enumerate(args.items()):
705
- is_prim_init = arg_info.get('prim_init')
706
- if not is_prim_init:
707
- continue
708
-
709
- dtype = get_dtype(arg_info)
710
- if dtype == "str":
711
- dtype = "std::string"
712
- if dtype in ("tuple[str]", "list[str]"):
713
- dtype = "std::vector<std::string>"
714
- if dtype in ("tuple[int]", "list[int]"):
715
- dtype = "std::vector<int64_t>"
716
- if dtype in ("tuple[float]", "list[float]"):
717
- dtype = "std::vector<float>"
718
- if dtype in ("tuple[bool]", "list[bool]"):
719
- dtype = "std::vector<bool>"
720
- if dtype == "int":
721
- dtype = "int64_t"
722
- lite_ops_h_gen += f""" void set_{arg_name}(const {dtype} &{arg_name});\n"""
723
- lite_ops_h_gen += f""" {dtype} get_{arg_name}() const;\n"""
724
-
725
- lite_ops_cc_gen += f"""void {op_name}::set_{arg_name}(const {dtype} &{arg_name}) \
726
- {{ (void)this->AddAttr("{arg_name}", api::MakeValue({arg_name})); }}\n\n"""
727
- lite_ops_cc_gen += f"""{dtype} {op_name}::get_{arg_name}() const \
728
- {{ return GetValue<{dtype}>(GetAttr("{arg_name}")); }}\n\n"""
729
-
730
- op_name = _get_op_name(operator_name, operator_data)
731
- lite_ops_cc_gen += f"""REGISTER_PRIMITIVE_C(kName{op_name}, {op_name});\n"""
732
- lite_ops_cc_gen += f"""MIND_API_OPERATOR_IMPL({op_name}, BaseOperator);\n\n"""
733
- lite_ops_h_gen += f"""}};\n\n"""
734
- lite_ops_h_gen += lite_ops_h_end
735
- lite_ops_cc_gen += lite_ops_cc_end
736
- return lite_ops_h_gen, lite_ops_cc_gen
737
-
738
-
739
- def generate_cc_opdef(yaml_data):
740
- """
741
- Generate c++ OpDef
742
- """
743
- gen_cc_code = f"""\n
744
- namespace mindspore::ops {{"""
745
- gen_include = f"""\n
746
- #include \"{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_def.h\""""
747
- gen_include += f"""
748
- #include \"ir/signature.h\""""
749
-
750
- for operator_name, operator_data in yaml_data.items():
751
- args = operator_data.get('args')
752
- class_name = _get_op_name(operator_name, operator_data)
753
- inputs_args, _, _, _, _, _ = process_args(class_name, args)
754
- signature_code = generate_cc_op_signature(operator_data.get('args_signature'), inputs_args)
755
- args = operator_data.get('args')
756
- returns = operator_data.get('returns')
757
- dispatch = operator_data.get("dispatch")
758
- # dispatch not defined in yaml or dispatch.enable==False
759
- if not dispatch or not dispatch.get("enable"):
760
- dispatch = "false"
761
- else:
762
- dispatch = "true"
763
- enable_dispatch_str = f"""{dispatch}"""
764
-
765
- is_view = operator_data.get('view')
766
- if is_view:
767
- is_view_s = "true"
768
- else:
769
- is_view_s = "false"
770
- is_view_str = f"""{is_view_s}"""
771
-
772
- gen_include += f"""\n#include "{K.MS_OPS_FUNC_IMPL_PATH}/{operator_name}.h\""""
773
- cc_index_str = ''
774
- input_args_str = ''
775
- args_dict = {}
776
- for i, (arg_name, arg_info) in enumerate(args.items()):
777
- args_dict[arg_name] = i
778
- cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
779
- dtype = get_dtype(arg_info)
780
- cc_dtype_str = convert_dtype_str(dtype)
781
-
782
- is_prim_init = 1 if arg_info.get('prim_init') else 0
783
- arg_handler = arg_info.get('arg_handler')
784
- arg_handler_str = "" if arg_handler is None else arg_handler
785
-
786
- type_cast = arg_info.get('type_cast')
787
- type_cast_str = "" if type_cast is None else \
788
- ', '.join('DT_' + type.replace('[', '_').replace(']', '').upper() for type in
789
- (ct.strip() for ct in type_cast.split(",")))
790
-
791
- # default: None is regarded as a optional argument.
792
- is_optional_str = "false"
793
- if 'default' in arg_info.keys() and arg_info.get('default') == "None":
794
- is_optional_str = "true"
795
-
796
- input_args_str += f"""\n {{/*.arg_name_=*/"{arg_name}", /*.arg_dtype_=*/{cc_dtype_str}, """ + \
797
- f"""/*.as_init_arg_=*/{is_prim_init}, /*.arg_handler_=*/"{arg_handler_str}", """ + \
798
- f"""/*.cast_dtype_ =*/{{{type_cast_str}}}, /*.is_optional_=*/{is_optional_str}}},"""
799
-
800
- # Process outputs.
801
- return_args_str = ''
802
- for return_name, return_info in returns.items():
803
- return_dtype = return_info.get('dtype')
804
- ref_name = return_info.get('inplace')
805
- ref_index_str = -1 if ref_name is None else args_dict.get(ref_name)
806
- cc_return_type_str = 'DT_' + return_dtype.replace('[', '_').replace(']', '').upper()
807
- return_args_str += f"""{{/*.arg_name_=*/"{return_name}", /*.arg_dtype_=*/{cc_return_type_str},
808
- /*.inplace_input_index_=*/{ref_index_str}}},\n"""
809
-
810
- op_def_cc = template.OP_PROTO_TEMPLATE.replace(class_name=class_name, input_args=input_args_str,
811
- return_args=return_args_str, signatures=signature_code,
812
- indexes=cc_index_str, enable_dispatch=enable_dispatch_str,
813
- is_view=is_view_str)
814
- gen_cc_code += op_def_cc
815
- if is_view:
816
- view_op_def = op_def_cc.replace(class_name, class_name+"View")
817
- gen_cc_code += view_op_def
818
-
819
- cc_opdef_end = f"""\n}} // namespace mindspore::ops\n"""
820
- return gen_include + gen_cc_code + cc_opdef_end
821
-
822
-
823
- ops_py_prim_header = f"""
824
- \"\"\"Operators definition generated by gen_ops.py, includes primitive classes.\"\"\"
825
-
826
- from mindspore.ops.primitive import Primitive, prim_arg_register
827
- from mindspore.ops import signature as sig
828
- from mindspore.common import dtype as mstype
829
- from mindspore.common._decorator import deprecated
830
- from mindspore.ops._primitive_cache import _get_cache_prim
831
- from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it
832
- from mindspore.ops.auto_generate.gen_arg_handler import *
833
- from mindspore._c_expression import OpDtype
834
- from mindspore.common._stub_tensor import _convert_stub
835
- """
836
-
837
- ops_py_def_header = f"""
838
- \"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
839
-
840
- from .gen_ops_prim import *
841
- from .pyboost_inner_prim import *
842
- from mindspore.ops.operations.manually_defined.ops_def import *
843
- from mindspore.ops._primitive_cache import _get_cache_prim
844
- """
845
-
846
-
847
- def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
848
- py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_prim.py')
849
- tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_prim.py')
850
- pyboost_import_header = generate_pyboost_import_header(yaml_str)
851
- py_prim = generate_py_primitive(yaml_str, doc_str)
852
- write_file(tmp_py_path, py_licence_str + ops_py_prim_header + pyboost_import_header + py_prim)
853
- check_change_and_replace_file(py_path, tmp_py_path)
854
-
855
-
856
- def generate_ops_def_file(work_path, yaml_str, doc_str, file_pre):
857
- py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_def.py')
858
- tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_def.py')
859
- py_func = generate_py_op_func(yaml_str, doc_str)
860
- write_file(tmp_py_path, py_licence_str + ops_py_def_header + py_func)
861
- check_change_and_replace_file(py_path, tmp_py_path)
862
-
863
-
864
- def generate_ops_py_files(work_path, yaml_str, doc_str, file_pre):
865
- """
866
- Generate ops python file from yaml.
867
- """
868
- generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre)
869
- generate_ops_def_file(work_path, yaml_str, doc_str, file_pre)
870
- shutil.copy(os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_auto_generate_init.txt'),
871
- os.path.join(work_path, K.PY_AUTO_GEN_PATH, "__init__.py"))
872
-
873
-
874
- def generate_ops_cc_files(work_path, yaml_str):
875
- """
876
- Generate ops c++ file from yaml.
877
- """
878
- # ops_def
879
- op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_def.cc')
880
- tmp_op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_def.cc')
881
- cc_def_code = generate_cc_opdef(yaml_str)
882
- write_file(tmp_op_cc_path, cc_license_str + cc_def_code)
883
- check_change_and_replace_file(op_cc_path, tmp_op_cc_path)
884
-
885
- # ops_primitive
886
- op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_primitive.h')
887
- tmp_op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_primitive.h')
888
- op_prim_code = generate_op_prim_opdef(yaml_str)
889
- write_file(tmp_op_prim_path, cc_license_str + op_prim_code)
890
- check_change_and_replace_file(op_prim_path, tmp_op_prim_path)
891
-
892
- # lite_h_ops
893
- lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.h')
894
- tmp_lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.h')
895
- lite_ops_h_code, lite_ops_cc_code = generate_lite_ops(yaml_str)
896
- write_file(tmp_lite_ops_h_path, cc_license_str + lite_ops_h_code)
897
- check_change_and_replace_file(lite_ops_h_path, tmp_lite_ops_h_path)
898
-
899
- # lite_cc_ops
900
- lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.cc')
901
- tmp_lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.cc')
902
- write_file(tmp_lite_ops_cc_path, cc_license_str + lite_ops_cc_code)
903
- check_change_and_replace_file(lite_ops_cc_path, tmp_lite_ops_cc_path)
904
-
905
- # ops_names
906
- op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_name.h')
907
- tmp_op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_name.h')
908
- op_name_code = generate_op_name_opdef(yaml_str)
909
- write_file(tmp_op_name_path, cc_license_str + op_name_code)
910
- check_change_and_replace_file(op_name_path, tmp_op_name_path)
911
-
912
-
913
- def generate_op_labels(yaml_data):
914
- """
915
- Generate python labels
916
- """
917
- gen_label_py = f"""op_labels = {{"""
918
- for operator_name, operator_data in yaml_data.items():
919
- labels = operator_data.get('labels')
920
- if labels is not None:
921
- class_name = _get_op_name(operator_name, operator_data)
922
- gen_label_py += f"""
923
- "{class_name}": {{"""
924
- gen_label_py += f""", """.join([f""""{key}": {value}""" for key, value in labels.items()])
925
- gen_label_py += f"""}},"""
926
- gen_label_py += f"""
927
- }}"""
928
- return gen_label_py
929
-
930
-
931
- def generate_op_arg_default_value(yaml_data):
932
- """
933
- Generate python default value.
934
- """
935
- default_py_header = f"""\"\"\"Operator labels and args default value.\"\"\"
936
- from mindspore.common import dtype as mstype\n\n"""
937
-
938
- gen_default_py = default_py_header + f"""op_args_default_value = {{"""
939
- for operator_name, operator_data in yaml_data.items():
940
- arg_default_dict = {}
941
- args = operator_data.get('args')
942
- for arg_name, arg_info in args.items():
943
- arg_default = arg_info.get('default')
944
- if arg_default is not None:
945
- arg_default_dict[arg_name] = arg_default
946
- if arg_default_dict:
947
- class_name = _get_op_name(operator_name, operator_data)
948
- gen_default_py += f"""
949
- "{class_name}": {{"""
950
- gen_default_py += f""", """.join([f""""{key}": {value}""" for key, value in arg_default_dict.items()])
951
- gen_default_py += f"""}},"""
952
- gen_default_py += f"""
953
- }}"""
954
- return gen_default_py
955
-
956
-
957
- def generate_create_instance_helper_file(work_path, yaml_str):
958
- """
959
- Generate C++ helper file from yaml.
960
- """
961
- dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
962
- op_py_path = os.path.join(dst_dir, 'cpp_create_prim_instance_helper.py')
963
- tmp_op_py_path = os.path.join(dst_dir, 'tmp_cpp_create_prim_instance_helper.py')
964
- py_labels = generate_op_labels(yaml_str)
965
- py_arg_default = generate_op_arg_default_value(yaml_str)
966
- write_file(tmp_op_py_path, py_licence_str + "\n" + py_arg_default + "\n\n" + py_labels + "\n")
967
- check_change_and_replace_file(op_py_path, tmp_op_py_path)
968
-
969
-
970
- def generate_aclnn_reg_code(yaml_data):
971
- """generate aclnn register code"""
972
- current_path = os.path.dirname(os.path.realpath(__file__))
973
- work_path = os.path.join(current_path, '../../../../')
974
- ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, "ops.yaml")
975
- yaml_str = gen_utils.safe_load_yaml(ops_yaml_path)
976
-
977
- reg_code = f"""
978
- #include "{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_mod.h"
979
-
980
- namespace mindspore {{
981
- namespace kernel {{
982
- """
983
- for operator_name, operator_data in yaml_data.items():
984
- dispatch = operator_data.get("dispatch")
985
- if not dispatch or not dispatch.get("enable"):
986
- continue
987
- Ascend = dispatch.get("Ascend")
988
- if Ascend is not None: # KernelMod is provided by yaml, don't auto generate it.
989
- continue
990
- _, _, none_tensor_exist = get_dtypes(operator_data)
991
- if none_tensor_exist:
992
- gen_aclnn_kernel(operator_name, yaml_str, auto=True)
993
- continue
994
- class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
995
- op_class = operator_data.get("class")
996
- if op_class and op_class.get("name") is not None:
997
- class_name = op_class.get("name")
998
- inputs_outputs_num = len(operator_data.get("args")) + len(operator_data.get("returns"))
999
- aclnn_name = AclnnUtils.get_aclnn_interface(class_name)
1000
- reg_code += f"""
1001
- MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});"""
1002
- reg_code += f"""
1003
- }} // namespace kernel
1004
- }} // namespace mindspore
1005
- """
1006
- return reg_code
1007
-
1008
-
1009
- def generate_aclnn_reg_file(work_path, yaml_str):
1010
- """
1011
- Generate nnacl kernelmod register
1012
- """
1013
- tmp_register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/tmp_aclnn_kernel_register.cc'
1014
- register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_register_auto.cc'
1015
- reg_code = generate_aclnn_reg_code(yaml_str)
1016
- write_file(tmp_register_file, cc_license_str + reg_code)
1017
- check_change_and_replace_file(register_file, tmp_register_file)
1018
-
1019
-
1020
- def generate_arg_handler_files(work_path):
1021
- """
1022
- Generate arg handler files.
1023
- """
1024
- dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
1025
- src_arg_handler_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'arg_handler.py')
1026
- dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
1027
- tmp_dst_arg_handler_path = os.path.join(dst_dir, 'tmp_gen_arg_handler.py')
1028
- if not os.path.exists(dst_dir):
1029
- os.makedirs(dst_dir, mode=0o700)
1030
- shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
1031
- check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)
1032
-
1033
- src_arg_dtype_cast_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'arg_dtype_cast.py')
1034
- dst_arg_dtype_cast_path = os.path.join(dst_dir, 'gen_arg_dtype_cast.py')
1035
- tmp_arg_dtype_cast_path = os.path.join(dst_dir, 'tmp_arg_dtype_cast.py')
1036
- shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
1037
- check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
1038
-
1039
-
1040
- def get_view_ops(yaml_data):
1041
- """
1042
- Get ops with view: True
1043
- """
1044
- view_ops = []
1045
- for operator_name, operator_data in yaml_data.items():
1046
- class_name = _get_op_name(operator_name, operator_data)
1047
- view = operator_data.get("view")
1048
- if view:
1049
- view_ops.append(class_name + "View")
1050
- return view_ops
30
+ module_generators = [
31
+ generate_ops_py_files, # generate ops python files
32
+ generate_ops_def_files, # generate ops definition files
33
+ gen_pyboost_code, # generate pyboost code
34
+ generate_aclnn_reg_file, # generate aclnn kernelmod register
35
+ generate_api_files # generate api definition files
36
+ ]
1051
37
 
1052
38
 
1053
39
  def main():
1054
- current_path = os.path.dirname(os.path.realpath(__file__))
1055
- work_path = os.path.join(current_path, '../../../../')
1056
-
1057
- # merge ops yaml
1058
- ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops.yaml')
1059
- doc_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_doc.yaml')
1060
-
1061
- ops_yaml_dir_path = os.path.join(work_path, K.MS_YAML_PATH)
1062
- infer_ops_yaml_dir_path = os.path.join(ops_yaml_dir_path, "infer")
1063
- doc_yaml_dir_path = os.path.join(ops_yaml_dir_path, "doc")
1064
- merge_files(ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
1065
- merge_files_append(infer_ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
1066
- merge_files(doc_yaml_dir_path, doc_yaml_path, '*doc.yaml')
1067
-
1068
- # make auto_generate dir
1069
- cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
1070
- pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True)
1071
-
1072
- # generate arg_handler files
1073
- generate_arg_handler_files(work_path)
1074
-
1075
- # read ops definition str and doc str
1076
- ops_yaml_str = safe_load_yaml(ops_yaml_path)
1077
- doc_yaml_str = safe_load_yaml(doc_yaml_path)
1078
-
1079
- # generate ops python files
1080
- generate_ops_py_files(work_path, ops_yaml_str, doc_yaml_str, "gen")
40
+ resource_mgr = prepare_resources()
1081
41
 
1082
- # generate ops c++ files
1083
- generate_ops_cc_files(work_path, ops_yaml_str)
1084
- # generate create prim instance helper file
1085
- generate_create_instance_helper_file(work_path, ops_yaml_str)
1086
- # get view extra ops
1087
- extra_ops = get_view_ops(ops_yaml_str)
1088
- # generate pyboost code
1089
- gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str, extra_ops)
1090
- # generate aclnn kernelmod register
1091
- generate_aclnn_reg_file(work_path, ops_yaml_str)
42
+ for generator in module_generators:
43
+ generator(resource_mgr)
1092
44
 
45
+ gen_utils.clear_obsolete_auto_gen_files()
1093
46
 
1094
47
  if __name__ == "__main__":
1095
48
  try:
1096
49
  main()
1097
- # pylint: disable=broad-except
1098
- except Exception as e:
50
+ except Exception as e: # pylint: disable=broad-except
1099
51
  logging.critical("Auto generate failed, err info: %s", e)
52
+ raise e