mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +46 -197
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +217 -98
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +435 -371
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +2 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +951 -1992
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +314 -566
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/experimental/es/embedding_service.py +35 -27
  112. mindspore/experimental/llm_boost/__init__.py +1 -0
  113. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  114. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  115. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  116. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  117. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  118. mindspore/experimental/llm_boost/register.py +1 -0
  119. mindspore/experimental/map_parameter.py +4 -4
  120. mindspore/experimental/optim/adadelta.py +6 -6
  121. mindspore/experimental/optim/adagrad.py +4 -4
  122. mindspore/experimental/optim/adam.py +7 -0
  123. mindspore/experimental/optim/adamax.py +4 -4
  124. mindspore/experimental/optim/adamw.py +4 -0
  125. mindspore/experimental/optim/asgd.py +1 -1
  126. mindspore/experimental/optim/lr_scheduler.py +73 -46
  127. mindspore/experimental/optim/radam.py +34 -31
  128. mindspore/experimental/optim/rprop.py +1 -1
  129. mindspore/experimental/optim/sgd.py +1 -1
  130. mindspore/hal/contiguous_tensors_handle.py +6 -10
  131. mindspore/hal/device.py +55 -53
  132. mindspore/hal/event.py +52 -52
  133. mindspore/hal/memory.py +157 -117
  134. mindspore/hal/stream.py +150 -109
  135. mindspore/include/api/context.h +0 -1
  136. mindspore/include/dataset/constants.h +7 -4
  137. mindspore/include/dataset/execute.h +2 -2
  138. mindspore/jpeg62.dll +0 -0
  139. mindspore/log.py +50 -0
  140. mindspore/mindrecord/__init__.py +21 -8
  141. mindspore/mindrecord/config.py +17 -316
  142. mindspore/mindrecord/filereader.py +1 -9
  143. mindspore/mindrecord/filewriter.py +5 -15
  144. mindspore/mindrecord/mindpage.py +1 -9
  145. mindspore/mindspore_backend_common.dll +0 -0
  146. mindspore/mindspore_backend_manager.dll +0 -0
  147. mindspore/mindspore_common.dll +0 -0
  148. mindspore/mindspore_core.dll +0 -0
  149. mindspore/mindspore_dump.dll +0 -0
  150. mindspore/mindspore_frontend.dll +0 -0
  151. mindspore/mindspore_memory_pool.dll +0 -0
  152. mindspore/mindspore_ms_backend.dll +0 -0
  153. mindspore/mindspore_ops.dll +0 -0
  154. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  155. mindspore/mindspore_ops_kernel_common.dll +0 -0
  156. mindspore/mindspore_profiler.dll +0 -0
  157. mindspore/mindspore_pyboost.dll +0 -0
  158. mindspore/mindspore_pynative.dll +0 -0
  159. mindspore/mindspore_res_manager.dll +0 -0
  160. mindspore/mindspore_runtime_pipeline.dll +0 -0
  161. mindspore/mint/__init__.py +796 -759
  162. mindspore/mint/distributed/__init__.py +70 -4
  163. mindspore/mint/distributed/distributed.py +2679 -44
  164. mindspore/mint/linalg/__init__.py +8 -0
  165. mindspore/mint/nn/__init__.py +743 -22
  166. mindspore/mint/nn/functional.py +716 -23
  167. mindspore/mint/nn/layer/__init__.py +21 -4
  168. mindspore/mint/nn/layer/_functions.py +334 -0
  169. mindspore/mint/nn/layer/activation.py +276 -1
  170. mindspore/mint/nn/layer/basic.py +123 -0
  171. mindspore/mint/nn/layer/conv.py +921 -0
  172. mindspore/mint/nn/layer/normalization.py +223 -28
  173. mindspore/mint/nn/layer/padding.py +797 -0
  174. mindspore/mint/nn/layer/pooling.py +235 -0
  175. mindspore/mint/optim/__init__.py +3 -1
  176. mindspore/mint/optim/adam.py +223 -0
  177. mindspore/mint/optim/adamw.py +26 -19
  178. mindspore/mint/optim/sgd.py +171 -0
  179. mindspore/mint/special/__init__.py +2 -1
  180. mindspore/multiprocessing/__init__.py +5 -0
  181. mindspore/nn/__init__.py +4 -1
  182. mindspore/nn/cell.py +1370 -189
  183. mindspore/nn/dynamic_lr.py +2 -1
  184. mindspore/nn/layer/activation.py +29 -27
  185. mindspore/nn/layer/basic.py +51 -35
  186. mindspore/nn/layer/channel_shuffle.py +3 -3
  187. mindspore/nn/layer/container.py +1 -1
  188. mindspore/nn/layer/conv.py +22 -17
  189. mindspore/nn/layer/embedding.py +12 -11
  190. mindspore/nn/layer/normalization.py +56 -49
  191. mindspore/nn/layer/padding.py +4 -3
  192. mindspore/nn/layer/pooling.py +120 -42
  193. mindspore/nn/layer/rnn_cells.py +1 -1
  194. mindspore/nn/layer/rnns.py +2 -1
  195. mindspore/nn/layer/timedistributed.py +5 -5
  196. mindspore/nn/layer/transformer.py +59 -36
  197. mindspore/nn/learning_rate_schedule.py +8 -4
  198. mindspore/nn/loss/loss.py +58 -55
  199. mindspore/nn/optim/ada_grad.py +7 -5
  200. mindspore/nn/optim/adadelta.py +11 -9
  201. mindspore/nn/optim/adafactor.py +1 -1
  202. mindspore/nn/optim/adam.py +17 -13
  203. mindspore/nn/optim/adamax.py +8 -7
  204. mindspore/nn/optim/adasum.py +5 -5
  205. mindspore/nn/optim/asgd.py +1 -1
  206. mindspore/nn/optim/ftrl.py +11 -9
  207. mindspore/nn/optim/lamb.py +1 -1
  208. mindspore/nn/optim/lars.py +1 -4
  209. mindspore/nn/optim/lazyadam.py +12 -10
  210. mindspore/nn/optim/momentum.py +7 -6
  211. mindspore/nn/optim/optimizer.py +3 -3
  212. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  213. mindspore/nn/optim/rmsprop.py +13 -12
  214. mindspore/nn/optim/rprop.py +11 -9
  215. mindspore/nn/optim/sgd.py +9 -6
  216. mindspore/nn/optim/tft_wrapper.py +5 -2
  217. mindspore/nn/optim/thor.py +2 -1
  218. mindspore/nn/probability/bijector/bijector.py +17 -11
  219. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  220. mindspore/nn/probability/bijector/invert.py +2 -2
  221. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  222. mindspore/nn/probability/bijector/softplus.py +3 -2
  223. mindspore/nn/probability/distribution/beta.py +3 -3
  224. mindspore/nn/probability/distribution/categorical.py +1 -1
  225. mindspore/nn/probability/distribution/cauchy.py +4 -2
  226. mindspore/nn/probability/distribution/exponential.py +6 -7
  227. mindspore/nn/probability/distribution/gamma.py +2 -2
  228. mindspore/nn/probability/distribution/gumbel.py +2 -2
  229. mindspore/nn/probability/distribution/half_normal.py +5 -3
  230. mindspore/nn/probability/distribution/logistic.py +5 -3
  231. mindspore/nn/probability/distribution/poisson.py +1 -1
  232. mindspore/nn/probability/distribution/uniform.py +5 -3
  233. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  234. mindspore/nn/reinforcement/tensor_array.py +1 -1
  235. mindspore/nn/utils/init.py +13 -11
  236. mindspore/nn/wrap/__init__.py +6 -6
  237. mindspore/nn/wrap/cell_wrapper.py +181 -122
  238. mindspore/nn/wrap/grad_reducer.py +45 -36
  239. mindspore/nn/wrap/loss_scale.py +6 -7
  240. mindspore/numpy/array_creations.py +63 -65
  241. mindspore/numpy/array_ops.py +149 -144
  242. mindspore/numpy/logic_ops.py +41 -42
  243. mindspore/numpy/math_ops.py +365 -363
  244. mindspore/numpy/utils.py +17 -18
  245. mindspore/numpy/utils_const.py +5 -6
  246. mindspore/opencv_core452.dll +0 -0
  247. mindspore/opencv_imgcodecs452.dll +0 -0
  248. mindspore/opencv_imgproc452.dll +0 -0
  249. mindspore/ops/__init__.py +5 -3
  250. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  251. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  252. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  253. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  254. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  255. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  256. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  257. mindspore/ops/_register_for_op.py +0 -11
  258. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  259. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  260. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  261. mindspore/ops/_vmap/vmap_base.py +0 -2
  262. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  263. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  264. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  265. mindspore/ops/auto_generate/__init__.py +4 -3
  266. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  267. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  268. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  269. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  270. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  271. mindspore/ops/composite/__init__.py +2 -1
  272. mindspore/ops/composite/base.py +20 -25
  273. mindspore/ops/composite/math_ops.py +6 -16
  274. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  275. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  276. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  277. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  278. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  279. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  282. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  283. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  284. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  285. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  286. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  287. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  288. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  289. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  291. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  292. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  293. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  294. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  295. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  297. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  301. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  302. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  304. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  305. mindspore/ops/function/__init__.py +40 -2
  306. mindspore/ops/function/_add_attr_func.py +58 -0
  307. mindspore/ops/function/array_func.py +2089 -2403
  308. mindspore/ops/function/clip_func.py +80 -23
  309. mindspore/ops/function/debug_func.py +57 -57
  310. mindspore/ops/function/grad/__init__.py +1 -0
  311. mindspore/ops/function/grad/grad_func.py +104 -71
  312. mindspore/ops/function/image_func.py +2 -2
  313. mindspore/ops/function/linalg_func.py +47 -78
  314. mindspore/ops/function/math_func.py +4501 -3802
  315. mindspore/ops/function/nn_func.py +1726 -620
  316. mindspore/ops/function/other_func.py +159 -1
  317. mindspore/ops/function/parameter_func.py +18 -84
  318. mindspore/ops/function/random_func.py +440 -387
  319. mindspore/ops/function/reshard_func.py +4 -70
  320. mindspore/ops/function/sparse_func.py +3 -3
  321. mindspore/ops/function/sparse_unary_func.py +6 -6
  322. mindspore/ops/function/spectral_func.py +25 -58
  323. mindspore/ops/function/vmap_func.py +24 -17
  324. mindspore/ops/functional.py +22 -7
  325. mindspore/ops/functional_overload.py +1440 -0
  326. mindspore/ops/op_info_register.py +32 -244
  327. mindspore/ops/operations/__init__.py +13 -7
  328. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  329. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  330. mindspore/ops/operations/_grad_ops.py +2 -43
  331. mindspore/ops/operations/_infer_ops.py +2 -1
  332. mindspore/ops/operations/_inner_ops.py +43 -84
  333. mindspore/ops/operations/_ms_kernel.py +4 -10
  334. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  335. mindspore/ops/operations/_scalar_ops.py +3 -2
  336. mindspore/ops/operations/_sequence_ops.py +1 -1
  337. mindspore/ops/operations/_tensor_array.py +1 -1
  338. mindspore/ops/operations/array_ops.py +81 -324
  339. mindspore/ops/operations/comm_ops.py +154 -108
  340. mindspore/ops/operations/custom_ops.py +232 -78
  341. mindspore/ops/operations/debug_ops.py +153 -59
  342. mindspore/ops/operations/inner_ops.py +7 -5
  343. mindspore/ops/operations/linalg_ops.py +1 -57
  344. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  345. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  346. mindspore/ops/operations/math_ops.py +32 -234
  347. mindspore/ops/operations/nn_ops.py +210 -498
  348. mindspore/ops/operations/other_ops.py +62 -9
  349. mindspore/ops/operations/random_ops.py +13 -7
  350. mindspore/ops/operations/reshard_ops.py +1 -1
  351. mindspore/ops/operations/sparse_ops.py +2 -2
  352. mindspore/ops/primitive.py +66 -53
  353. mindspore/ops/tensor_method.py +1888 -0
  354. mindspore/ops_generate/__init__.py +0 -5
  355. mindspore/ops_generate/aclnn/__init__.py +0 -0
  356. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  357. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  358. mindspore/ops_generate/api/__init__.py +0 -0
  359. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  360. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  361. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  362. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  363. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  364. mindspore/ops_generate/api/gen_api.py +103 -0
  365. mindspore/ops_generate/api/op_api_proto.py +235 -0
  366. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  367. mindspore/ops_generate/common/__init__.py +0 -0
  368. mindspore/ops_generate/common/base_generator.py +11 -0
  369. mindspore/ops_generate/common/gen_constants.py +91 -0
  370. mindspore/ops_generate/common/gen_utils.py +348 -0
  371. mindspore/ops_generate/common/op_proto.py +473 -0
  372. mindspore/ops_generate/common/template.py +523 -0
  373. mindspore/ops_generate/gen_ops.py +22 -1069
  374. mindspore/ops_generate/op_def/__init__.py +0 -0
  375. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  376. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  377. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  378. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  379. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  380. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  381. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  382. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  383. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  384. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  385. mindspore/ops_generate/pyboost/__init__.py +0 -0
  386. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  387. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  388. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  389. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  390. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  391. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  393. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  394. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  395. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  396. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  397. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  398. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  399. mindspore/ops_generate/resources/__init__.py +0 -0
  400. mindspore/ops_generate/resources/resource_list.py +30 -0
  401. mindspore/ops_generate/resources/resource_loader.py +36 -0
  402. mindspore/ops_generate/resources/resource_manager.py +64 -0
  403. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  404. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  405. mindspore/parallel/__init__.py +7 -3
  406. mindspore/parallel/_auto_parallel_context.py +152 -34
  407. mindspore/parallel/_cell_wrapper.py +130 -15
  408. mindspore/parallel/_parallel_serialization.py +107 -5
  409. mindspore/parallel/_ps_context.py +1 -1
  410. mindspore/parallel/_recovery_context.py +7 -2
  411. mindspore/parallel/_tensor.py +142 -18
  412. mindspore/parallel/_utils.py +199 -23
  413. mindspore/parallel/algo_parameter_config.py +4 -4
  414. mindspore/parallel/auto_parallel.py +732 -0
  415. mindspore/parallel/checkpoint_convert.py +159 -0
  416. mindspore/parallel/checkpoint_transform.py +698 -35
  417. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  418. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  419. mindspore/parallel/cluster/run.py +21 -4
  420. mindspore/parallel/function/__init__.py +24 -0
  421. mindspore/parallel/function/reshard_func.py +259 -0
  422. mindspore/parallel/nn/__init__.py +25 -0
  423. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  424. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  425. mindspore/parallel/parameter_broadcast.py +25 -14
  426. mindspore/parallel/shard.py +137 -58
  427. mindspore/parallel/transform_safetensors.py +363 -305
  428. mindspore/profiler/__init__.py +22 -5
  429. mindspore/profiler/analysis/__init__.py +0 -0
  430. mindspore/profiler/analysis/parser/__init__.py +0 -0
  431. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  432. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  433. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  434. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  435. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  436. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  437. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  440. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  441. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  446. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  447. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  448. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  449. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  450. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  451. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  452. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  453. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  454. mindspore/profiler/analysis/task_manager.py +131 -0
  455. mindspore/profiler/analysis/time_converter.py +84 -0
  456. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  457. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  458. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  459. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  460. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  461. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  462. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  463. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  464. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  465. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  466. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  467. mindspore/profiler/analysis/work_flow.py +73 -0
  468. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  469. mindspore/profiler/common/command_executor.py +90 -0
  470. mindspore/profiler/common/constant.py +186 -3
  471. mindspore/profiler/common/file_manager.py +208 -0
  472. mindspore/profiler/common/log.py +130 -0
  473. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  474. mindspore/profiler/common/path_manager.py +395 -0
  475. mindspore/profiler/common/process_bar.py +168 -0
  476. mindspore/profiler/common/process_pool.py +9 -3
  477. mindspore/profiler/common/profiler_context.py +500 -0
  478. mindspore/profiler/common/profiler_info.py +304 -0
  479. mindspore/profiler/common/profiler_meta_data.py +74 -0
  480. mindspore/profiler/common/profiler_output_path.py +284 -0
  481. mindspore/profiler/common/profiler_parameters.py +251 -0
  482. mindspore/profiler/common/profiler_path_manager.py +179 -0
  483. mindspore/profiler/common/record_function.py +76 -0
  484. mindspore/profiler/common/tlv_decoder.py +76 -0
  485. mindspore/profiler/common/util.py +75 -2
  486. mindspore/profiler/dynamic_profiler.py +341 -75
  487. mindspore/profiler/envprofiler.py +163 -0
  488. mindspore/profiler/experimental_config.py +197 -0
  489. mindspore/profiler/mstx.py +242 -0
  490. mindspore/profiler/platform/__init__.py +21 -0
  491. mindspore/profiler/platform/base_profiler.py +40 -0
  492. mindspore/profiler/platform/cpu_profiler.py +124 -0
  493. mindspore/profiler/platform/gpu_profiler.py +74 -0
  494. mindspore/profiler/platform/npu_profiler.py +335 -0
  495. mindspore/profiler/profiler.py +1073 -90
  496. mindspore/profiler/profiler_action_controller.py +187 -0
  497. mindspore/profiler/profiler_interface.py +118 -0
  498. mindspore/profiler/schedule.py +243 -0
  499. mindspore/rewrite/api/node.py +15 -13
  500. mindspore/rewrite/api/symbol_tree.py +2 -3
  501. mindspore/run_check/_check_version.py +27 -20
  502. mindspore/run_check/run_check.py +1 -1
  503. mindspore/runtime/__init__.py +37 -0
  504. mindspore/runtime/device.py +27 -0
  505. mindspore/runtime/event.py +209 -0
  506. mindspore/runtime/executor.py +177 -0
  507. mindspore/runtime/memory.py +409 -0
  508. mindspore/runtime/stream.py +460 -0
  509. mindspore/runtime/thread_bind_core.py +401 -0
  510. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  511. mindspore/swresample-4.dll +0 -0
  512. mindspore/swscale-6.dll +0 -0
  513. mindspore/tinyxml2.dll +0 -0
  514. mindspore/train/__init__.py +8 -8
  515. mindspore/train/_utils.py +88 -25
  516. mindspore/train/amp.py +9 -5
  517. mindspore/train/callback/__init__.py +2 -2
  518. mindspore/train/callback/_callback.py +2 -16
  519. mindspore/train/callback/_checkpoint.py +53 -55
  520. mindspore/train/callback/_cluster_monitor.py +14 -18
  521. mindspore/train/callback/_early_stop.py +1 -1
  522. mindspore/train/callback/_flops_collector.py +103 -68
  523. mindspore/train/callback/_history.py +8 -5
  524. mindspore/train/callback/_lambda_callback.py +2 -2
  525. mindspore/train/callback/_landscape.py +0 -3
  526. mindspore/train/callback/_loss_monitor.py +2 -1
  527. mindspore/train/callback/_on_request_exit.py +6 -5
  528. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  529. mindspore/train/callback/_summary_collector.py +52 -19
  530. mindspore/train/callback/_time_monitor.py +2 -1
  531. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  532. mindspore/train/data_sink.py +25 -2
  533. mindspore/train/dataset_helper.py +15 -16
  534. mindspore/train/loss_scale_manager.py +8 -7
  535. mindspore/train/metrics/accuracy.py +3 -3
  536. mindspore/train/metrics/confusion_matrix.py +9 -9
  537. mindspore/train/metrics/error.py +3 -3
  538. mindspore/train/metrics/hausdorff_distance.py +4 -4
  539. mindspore/train/metrics/mean_surface_distance.py +3 -3
  540. mindspore/train/metrics/metric.py +0 -12
  541. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  542. mindspore/train/metrics/precision.py +11 -10
  543. mindspore/train/metrics/recall.py +9 -9
  544. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  545. mindspore/train/mind_ir_pb2.py +174 -46
  546. mindspore/train/model.py +184 -113
  547. mindspore/train/serialization.py +622 -978
  548. mindspore/train/summary/_summary_adapter.py +2 -2
  549. mindspore/train/summary/summary_record.py +2 -3
  550. mindspore/train/train_thor/model_thor.py +1 -1
  551. mindspore/turbojpeg.dll +0 -0
  552. mindspore/utils/__init__.py +6 -3
  553. mindspore/utils/dryrun.py +140 -0
  554. mindspore/utils/hooks.py +81 -0
  555. mindspore/utils/runtime_execution_order_check.py +550 -0
  556. mindspore/utils/utils.py +138 -4
  557. mindspore/version.py +1 -1
  558. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  559. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  561. mindspore/_install_custom.py +0 -43
  562. mindspore/common/_register_for_adapter.py +0 -74
  563. mindspore/common/_tensor_overload.py +0 -139
  564. mindspore/mindspore_np_dtype.dll +0 -0
  565. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  566. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  567. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  568. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  569. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  570. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  571. mindspore/ops_generate/gen_utils.py +0 -209
  572. mindspore/ops_generate/op_proto.py +0 -145
  573. mindspore/ops_generate/template.py +0 -261
  574. mindspore/profiler/envprofiling.py +0 -254
  575. mindspore/profiler/profiling.py +0 -1926
  576. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  577. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,489 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Module for generating Python primitive operator definitions from specifications.
17
+ """
18
+
19
+ import os
20
+
21
+ import common.gen_constants as K
22
+ import common.gen_utils as gen_utils
23
+ import common.template as template
24
+ from common.base_generator import BaseGenerator
25
+ from common.op_proto import OpProto
26
+ from common.template import Template
27
+ from pyboost import pyboost_utils
28
+
29
+
30
class OpPrimPyGenerator(BaseGenerator):
    """
    Generates Python code for primitive operators based on provided specifications.

    The generated source is assembled from string fragments, so the indentation of every emitted
    line is spelled out explicitly:

    - members of the generated class (decorators, ``def`` lines, the ``__mindspore_signature__``
      assignment) use ``_MEMBER_INDENT``;
    - statements inside the generated methods use ``_BODY_INDENT``;
    - statements nested one block deeper (under an ``if``) use ``_NESTED_INDENT``.

    A method body must be indented deeper than its ``def`` line, otherwise the generated file
    fails to import with ``IndentationError``.
    """

    # Indentation of members of the generated primitive class.
    _MEMBER_INDENT = " " * 4
    # Indentation of statements inside the generated methods.
    _BODY_INDENT = " " * 8
    # Indentation of statements nested under an ``if`` inside the generated methods.
    _NESTED_INDENT = " " * 12

    def __init__(self):
        """
        Initializes the generator with a template for defining operator primitive classes.
        """
        self.op_prim_class_define_template = template.OP_PRIM_CLASS_DEFINE_TEMPLATE

    def generate(self, work_path, op_protos, doc_dict, file_pre):
        """
        Generates Python code for operator primitives and saves it to a file.

        Ops whose ``op_class.disable`` flag is set are skipped. For every op without ``__init__``
        arguments, a module-level instance named ``<op_name>_op`` is emitted as well.

        Args:
            work_path (str): The directory to save the generated files.
            op_protos (list): A list of operator prototypes.
            doc_dict (dict): A dictionary containing documentation strings.
            file_pre (str): The prefix for the generated file names; the output file is
                ``<file_pre>_ops_prim.py`` under ``K.PY_AUTO_GEN_PATH``.
        """
        gen_py_parts = []
        for op_proto in op_protos:
            if op_proto.op_class.disable:
                continue

            inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = (
                self._process_args(op_proto))

            # add class description
            class_desc = self._generate_class_desc(op_proto, inputs_args, init_args, doc_dict)
            # add signature
            signature_code = self._generate_py_op_signature(op_proto, inputs_args, inputs_default)
            # add deprecated
            deprecated_code = generate_py_op_deprecated(op_proto.op_deprecated)
            # add __init__ method code
            init_method = self._generate_init_code(args_assign, init_args_with_default, op_proto)
            # add __call__ method code
            call_method = self._generate_call_code(args_handlers, init_args, inputs_args, inputs_default, op_proto)

            # generate op prim class define
            op_prim_class_define = self.op_prim_class_define_template.replace(class_name=op_proto.op_class.name,
                                                                              class_desc=class_desc,
                                                                              signature_code=signature_code,
                                                                              deprecated_code=deprecated_code,
                                                                              init_method=init_method,
                                                                              call_method=call_method)
            # Append one extra newline when the __call__ code already ends with one.
            op_prim_class_define += "\n" if call_method.endswith("\n") else ""
            gen_py_parts.append(op_prim_class_define)

            # add prim_op_object: ops without init args can be instantiated once at module level.
            if not init_args:
                gen_py_parts.append(f"\n\n{op_proto.op_name}_op={op_proto.op_class.name}()\n")

        pyboost_import_header = self.generate_pyboost_import_header(op_protos)
        res_str = template.PY_LICENSE_STR + \
            template.OPS_PY_PRIM_HEADER + pyboost_import_header + "".join(gen_py_parts)

        save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
        file_name = f"{file_pre}_ops_prim.py"
        gen_utils.save_file(save_path, file_name, res_str)

    def generate_pyboost_import_header(self, op_protos) -> str:
        """
        Generates import statements for PyBoost primitives.

        One ``from mindspore._c_expression import <pyboost name>`` line is emitted for every op
        whose ``op_dispatch`` is set and enabled.

        Args:
            op_protos (list): A list of operator prototypes.

        Returns:
            str: A string containing import statements (empty if no op is dispatched).
        """
        import_pyboost = Template("from mindspore._c_expression import $var\n")
        return "".join(
            import_pyboost.replace(var=pyboost_utils.get_pyboost_name(op_proto.op_name))
            for op_proto in op_protos
            if op_proto.op_dispatch and op_proto.op_dispatch.enable)

    def _process_args(self, op_proto: OpProto):
        """
        Processes operator arguments to categorize them for code generation.

        Arguments flagged ``is_prim_init`` become ``__init__`` parameters; all others are
        ``__call__`` inputs.

        Args:
            op_proto (OpProto): The operator prototype.

        Returns:
            tuple: ``(inputs_name, inputs_default, args_name, args_assign, init_args_with_default,
            args_handlers)`` where ``args_assign`` holds the generated ``__init__`` body lines and
            ``args_handlers`` maps input names to their handler function names.
        """
        inputs_name = []
        args_name = []
        args_assign = []
        inputs_default = {}
        init_args_with_default = []
        args_handlers = {}

        for arg in op_proto.op_args:
            # step1: get args infos:
            if arg.is_prim_init:
                # step1.1: get args name:
                args_name.append(arg.arg_name)
                # step1.2: get args assign with default value:
                if arg.default is not None:
                    init_args_with_default.append(f"{arg.arg_name}={arg.default}")
                else:
                    init_args_with_default.append(f"{arg.arg_name}")

                # step1.3: get args set prim arg expression (a statement of the generated __init__):
                assign_str = self._get_assign_str_by_type_it(op_proto.op_class.name, arg)
                if arg.arg_handler:
                    assign_str = (f'{self._BODY_INDENT}self._set_prim_arg_with_handler('
                                  f'"{arg.arg_name}", {assign_str}, {arg.arg_handler})')
                else:
                    assign_str = f'{self._BODY_INDENT}self._set_prim_arg("{arg.arg_name}", {assign_str})'
                args_assign.append(assign_str)
            # step2: get inputs infos:
            else:
                # step2.1: get inputs name:
                inputs_name.append(arg.arg_name)

                # step2.2: get default value of inputs:
                if arg.default is not None:
                    inputs_default[arg.arg_name] = arg.default

                # step2.3: get args_handler functions for inputs
                if arg.arg_handler:
                    args_handlers[arg.arg_name] = arg.arg_handler

        return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers

    def _get_assign_str_by_type_it(self, class_name, arg):
        """
        Generates the value expression for an init argument, wrapped in ``type_it`` if it needs casting.

        Args:
            class_name (str): The name of the class.
            arg (OpArg): The operator argument.

        Returns:
            str: ``type_it('<class>', '<arg>', <arg>, <src type(s)>, <dst type>)`` when
            ``arg.type_cast`` is non-empty, otherwise just the argument name.
        """
        type_cast = arg.type_cast
        if not type_cast:
            return arg.arg_name

        assign_str = f"type_it('{class_name}', '{arg.arg_name}', {arg.arg_name}, "
        if len(type_cast) == 1:
            assign_str += gen_utils.get_type_str(type_cast[0]) + ', '
        else:
            assign_str += '(' + ', '.join(gen_utils.get_type_str(ct) for ct in type_cast) + '), '
        assign_str += gen_utils.get_type_str(arg.arg_dtype) + ')'
        return assign_str

    def _generate_class_desc(self, op_proto: OpProto, input_args, init_args, doc_dic):
        """
        Generates a class description based on the operator prototype.

        Args:
            op_proto (OpProto): The operator prototype.
            input_args (list): List of input argument names.
            init_args (list): List of initialization argument names.
            doc_dic (dict): Documentation dictionary.

        Returns:
            str: A string containing the class description.
        """
        if op_proto.op_function and op_proto.op_function.disable:
            # if function disabled, function name is equal to operator_name
            return gen_utils.get_op_description(op_proto.op_name, doc_dic)

        # If function is a released API, refer to the function doc.
        description_template = Template(template.PRIMITIVE_CLASS_DESC)
        return description_template.replace(class_name=op_proto.op_class.name,
                                            init_args_str=", ".join(init_args),
                                            input_args_str=", ".join(input_args),
                                            func_name=op_proto.op_function.name,
                                            args_str=", ".join(input_args + init_args))

    def _generate_init_code(self, args_assign, init_args_with_default, op_proto: OpProto):
        """
        Generates the __init__ method code for the operator primitive class.

        Args:
            args_assign (list): List of argument assignment statements (already body-indented).
            init_args_with_default (list): List of initialization arguments with default values.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: The decorated ``__init__`` definition, followed by a blank line.
        """
        init_args_list_str = ""
        if init_args_with_default:
            init_args_list_str = ", " + ", ".join(init_args_with_default)
        init_code = self._get_init_code("\n".join(args_assign), op_proto)
        return (f"{self._MEMBER_INDENT}@prim_arg_register\n"
                f"{self._MEMBER_INDENT}def __init__(self{init_args_list_str}):\n"
                f"{init_code}\n"
                "\n")

    def _get_init_code(self, init_code, op_proto: OpProto):
        """
        Appends ``add_prim_attr`` calls for the op labels to the __init__ body.

        Args:
            init_code (str): Existing initialization code (may be empty).
            op_proto (OpProto): The operator prototype.

        Returns:
            str: The complete __init__ body; a body-indented ``pass`` when nothing else is emitted.
        """
        labels_dic = op_proto.op_labels
        if labels_dic:
            if init_code:
                init_code += "\n"
            init_code += "\n".join(f'{self._BODY_INDENT}self.add_prim_attr("{k}", {v})'
                                   for k, v in labels_dic.items())

        return init_code if init_code else f"{self._BODY_INDENT}pass"

    def _generate_call_code(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
        """
        Generates the __call__ method code for the operator primitive class.

        Args:
            args_handlers (dict): Dictionary of argument handlers.
            init_args (list): List of initialization argument names.
            inputs_args (list): List of input argument names.
            inputs_default (dict): Dictionary of default input values.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: A string containing the __call__ method code.
        """
        call_args = [f"{name}={inputs_default[name]}" if name in inputs_default else name
                     for name in inputs_args]
        call_method_args_str = ", ".join(call_args)
        call_method_body_str = self._get_call_method_body_str(args_handlers, init_args, inputs_args, inputs_default,
                                                              op_proto)
        return f"{self._MEMBER_INDENT}def __call__(self, {call_method_args_str}):{call_method_body_str}"

    def _get_call_method_body_str(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
        """
        Generates the body of the __call__ method.

        Inputs that have a handler are wrapped with it (guarded by ``is None`` when the input
        defaults to ``None``). The values stored in ``__init__`` (``self.<init_arg>``) are passed
        after the inputs.

        Args:
            args_handlers (dict): Dictionary of argument handlers.
            init_args (list): List of initialization argument names.
            inputs_args (list): List of input argument names.
            inputs_default (dict): Dictionary of default input values.
            op_proto (OpProto): The operator prototype.

        Returns:
            str: The body of the call method. Every line starts with ``\\n`` (the first one
            terminates the ``def __call__(...):`` header), and the body ends with a newline.
        """
        call_args = []
        for arg in inputs_args:
            if arg in args_handlers:
                is_optional = inputs_default.get(arg) == "None"
                call_args.append(
                    _generate_arg_handler(op_proto.op_class.name, arg, args_handlers[arg], is_optional))
            else:
                call_args.append(arg)
        # Joining a single list avoids a dangling ", " when either group is empty.
        call_args.extend(f"self.{arg}" for arg in init_args)
        call_args_list_str = ", ".join(call_args)

        body = self._BODY_INDENT
        nested = self._NESTED_INDENT
        is_pyboost = op_proto.op_dispatch and op_proto.op_dispatch.enable
        if is_pyboost:
            pyboost_func_name = pyboost_utils.get_pyboost_name(op_proto.op_name)
            lines = [
                f"{body}# Add for jit context.",
                f"{body}if jit_context() and jit_context().compiled:",
                f"{nested}return None",
                f"{body}res = {pyboost_func_name}(self, [{call_args_list_str}])",
                f"{body}# Add for jit context.",
                f"{body}if jit_context():",
                f"{nested}return jit_context().run_op(self, res, {call_args_list_str})",
                f"{body}return res",
            ]
        else:
            lines = [f"{body}return super().__call__({call_args_list_str})"]
        return "".join(f"\n{line}" for line in lines) + "\n"

    def _generate_py_op_signature(self, op_proto: OpProto, args_name, args_default):
        """
        Generates the __mindspore_signature__ for the operator.

        Args:
            op_proto (OpProto): The operator prototype.
            args_name (list): List of argument names.
            args_default (dict): Dictionary of default argument values.

        Returns:
            str: A string containing the __mindspore_signature__ code, or ``''`` when the op has
            neither an args signature nor input defaults.

        Raises:
            ValueError: If the signature references an argument the op does not have.
        """
        op_name = op_proto.op_name
        args_signature = op_proto.op_args_signature

        if args_signature is None and not args_default:
            return ''

        signature_code = f"\n{self._MEMBER_INDENT}__mindspore_signature__ = "

        # Init rw.
        read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
        _check_signature_arg_valid(op_name, write_list, args_name)
        _check_signature_arg_valid(op_name, read_list, args_name)
        _check_signature_arg_valid(op_name, ref_list, args_name)

        # Init dtype group.
        same_dtype_groups, dtype_count = gen_utils.get_same_dtype_groups(args_signature, args_name)
        _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)

        # Only one dtype_group is set: emit the compact tuple-of-dtypes form.
        if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
            signature_code += '(' + 'sig.sig_dtype.T, ' * (len(args_name) - 1) + 'sig.sig_dtype.T)\n'
            return signature_code

        # Set sig.make_sig.
        signature_code += f"{self._MEMBER_INDENT}(\n"
        for arg_name in args_name:
            signature_code += f"{self._BODY_INDENT}sig.make_sig('{arg_name}'"
            signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
            if arg_name in same_dtype_groups:
                signature_code += ", " + signature_get_dtype_label(same_dtype_groups[arg_name])
            if arg_name in args_default:
                signature_code += ", default=" + str(args_default[arg_name])
            signature_code += "),\n"
        signature_code += f"{self._MEMBER_INDENT})\n"
        return signature_code
385
+
386
+
387
+ def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
388
+ """
389
+ Validates that all signature arguments are present in the list of argument names.
390
+
391
+ Args:
392
+ op_name (str): The name of the operator.
393
+ sig_arg_names (list): List of signature argument names.
394
+ args_names (list): List of actual argument names.
395
+
396
+ Raises:
397
+ ValueError: If a signature argument is not found in the list of argument names.
398
+ """
399
+ for sig_arg_name in sig_arg_names:
400
+ if sig_arg_name not in args_names:
401
+ raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
402
+
403
+
404
def signature_get_dtype_label(index):
    """
    Build the ``dtype=`` keyword fragment used in a ``sig.make_sig`` call.

    Args:
        index (int): The dtype-group index; positive values are appended as a suffix
            (``T1``, ``T2``, ...), anything else maps to the bare ``T``.

    Returns:
        str: e.g. ``"dtype=sig.sig_dtype.T"`` or ``"dtype=sig.sig_dtype.T2"``.
    """
    suffix = f"{index}" if index > 0 else ''
    return "dtype=sig.sig_dtype.T" + suffix
418
+
419
+
420
def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
    """
    Return the read/write keyword fragment for one argument of a ``sig.make_sig`` call.

    The lists are consulted in priority order write, read, ref; the first list that contains
    ``arg_name`` decides the label.

    Args:
        arg_name (str): The name of the argument.
        write_list (list): Names of writable arguments.
        read_list (list): Names of readable arguments.
        ref_list (list): Names of reference arguments.

    Returns:
        str: ``', sig.sig_rw.RW_WRITE'``, ``', sig.sig_rw.RW_READ'``, ``', sig.sig_rw.RW_REF'``,
        or ``''`` when the argument appears in none of the lists.
    """
    rw_table = (
        (write_list, ', sig.sig_rw.RW_WRITE'),
        (read_list, ', sig.sig_rw.RW_READ'),
        (ref_list, ', sig.sig_rw.RW_REF'),
    )
    for names, label in rw_table:
        if any(name == arg_name for name in names):
            return label
    return ''
443
+
444
+
445
def generate_py_op_deprecated(deprecated):
    """
    Generates the deprecated decorator for an operator.

    Args:
        deprecated (dict): The deprecation information with keys ``version``, ``substitute`` and
            ``use_substitute``, or ``None`` when the operator is not deprecated.

    Returns:
        str: A string containing the deprecated decorator line (with trailing newline), or ``''``
        when ``deprecated`` is ``None``.

    Raises:
        ValueError: If a required key is missing/``None`` or ``use_substitute`` is not a bool.
    """
    if deprecated is None:
        return ''
    version = deprecated.get("version")
    if version is None:
        raise ValueError("The version of deprecated can't be None.")
    substitute = deprecated.get("substitute")
    if substitute is None:
        raise ValueError("The substitute of deprecated can't be None.")
    use_substitute = deprecated.get("use_substitute")
    if use_substitute is None:
        raise ValueError("The use_substitute of deprecated can't be None.")
    # bool cannot be subclassed, so this accepts exactly True and False (not 1/0).
    if not isinstance(use_substitute, bool):
        raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")

    # The decorator is emitted into the generated primitive class body, so it uses the 4-space
    # class-member indent; a 1-space indent does not line up with 4-space members.
    # NOTE(review): confirm placement against OP_PRIM_CLASS_DEFINE_TEMPLATE.
    return f'    @deprecated("{version}", "{substitute}", {use_substitute})\n'
471
+
472
+
473
+ def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
474
+ """
475
+ Generates the argument handler call for an argument.
476
+
477
+ Args:
478
+ class_name (str): The name of the class.
479
+ arg (str): The name of the argument.
480
+ arg_handler (str): The handler function for the argument.
481
+ is_optional (bool): Indicates whether the argument is optional.
482
+
483
+ Returns:
484
+ str: The argument handler call string.
485
+ """
486
+ arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
487
+ if is_optional:
488
+ arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
489
+ return arg_handler_call
File without changes
@@ -0,0 +1,139 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ This module provides a generator class for creating C++ implementation files for AutoGrad functionality.
17
+ """
18
+
19
+ import os
20
+
21
+ import common.template as template
22
+ from common.template import Template
23
+ import common.gen_constants as K
24
+ from common.gen_utils import save_file
25
+ from common.base_generator import BaseGenerator
26
+ from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output
27
+
28
+
29
class AutoGradImplGenerator(BaseGenerator):
    """
    Generates C++ implementation files for the AutoGrad functionality based on operator prototypes.
    """

    def __init__(self):
        """
        Initialize the AutoGrad implementation generator with templates for code generation.
        """
        self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE
        self.AUTO_GRAD_IMPL_CC_TEMPLATE = template.AUTO_GRAD_IMPL_CC_TEMPLATE
        self.DO_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_GRAD_FUNCTION_BODY_TEMPLATE
        # Registers DoGrad<Op> as the op's grad function in the AutoGradFactory.
        self.auto_grad_reg_template = Template(
            "const_cast<kernel::pyboost::${class_name}GradFunc&>("
            "kernel::pyboost::AutoGradFactory::Get()."
            "ops_auto_grad_registers().${class_name}GradFuncObj) = "
            "kernel::pyboost::${class_name}GradFunc(DoGrad${class_name});")
        self.do_grad_op_args_with_type = Template(
            "const kernel::pyboost::OpPtr &op, ${input_args_with_type}"
        )

    def generate(self, work_path, op_protos):
        """
        Generate the AutoGrad implementation file ``auto_grad_impl.cc``.

        Ops whose ``op_dispatch`` is ``None`` or flagged ``is_comm_op`` are skipped.

        Args:
            work_path (str): The directory where the generated file should be saved.
            op_protos (list): A list of operator prototypes used to generate the implementation.
        """
        auto_grad_reg_list = []
        do_grad_op_list = []
        ops_inc_head_set = set()
        for op_proto in op_protos:
            if op_proto.op_dispatch is None or op_proto.op_dispatch.is_comm_op:
                continue
            class_name = op_proto.op_class.name
            auto_grad_reg_list.append(self.auto_grad_reg_template.replace(class_name=class_name))
            do_grad_op_list.append(self._get_single_do_grad_op(op_proto))
            # Include headers are grouped by the lower-cased first letter of the class name.
            ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=class_name[0].lower()))
        auto_grad_impl_cc_str = self.AUTO_GRAD_IMPL_CC_TEMPLATE.replace(do_grad_op=do_grad_op_list,
                                                                        auto_grad_reg=auto_grad_reg_list,
                                                                        ops_inc=sorted(ops_inc_head_set))
        save_path = os.path.join(work_path, K.PYBOOST_AUTO_GRAD_FUNC_GEN_PATH)
        file_name = "auto_grad_impl.cc"
        save_file(save_path, file_name, auto_grad_impl_cc_str)

    def _get_single_do_grad_op(self, op_proto):
        """
        Generate the DoGrad function for a single operator prototype.

        Args:
            op_proto: The operator prototype for which the DoGrad function is generated.

        Returns:
            str: The generated DoGrad function string.
        """
        input_args = self._get_input_args(op_proto, False, False)
        input_args_with_optional = self._get_input_args(op_proto, False, True)
        input_args_with_type = self._get_input_args(op_proto, True, False)
        multi_output_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
        view_arg_str = self._get_view_str(op_proto.op_view, input_args)
        grad_args_with_type_str = self.do_grad_op_args_with_type.replace(input_args_with_type=input_args_with_type)
        op_def_name_str = "g" + op_proto.op_class.name
        bprop_expander = "true" if op_proto.bprop_expander else "false"
        return self.DO_GRAD_FUNCTION_BODY_TEMPLATE.replace(class_name=op_proto.op_class.name,
                                                           grad_args_with_type=grad_args_with_type_str,
                                                           grad_input_args=input_args,
                                                           grad_input_args_with_optional=input_args_with_optional,
                                                           is_multi=multi_output_str,
                                                           view_arg=view_arg_str,
                                                           op_def_name=op_def_name_str,
                                                           bprop_expander=bprop_expander)

    def _get_input_args(self, op_proto, has_type, with_optional):
        """
        Get the input arguments for the DoGrad function.

        Args:
            op_proto: The operator prototype.
            has_type (bool): Emit typed C++ parameter declarations (``const <T> &<name>_tensor``).
            with_optional (bool): Only used when ``has_type`` is false. If false, optional
                params are wrapped in ``OptionalToValue(...)``; if true, they are passed as-is.

        Returns:
            list: A list of input argument strings for the DoGrad function.
        """
        args_list = []
        for op_arg in op_proto.op_args:
            tensor_name = f"{op_arg.arg_name}_tensor"
            if has_type:
                input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional_param(op_arg))
                args_list.append(f"const {input_dtype} &{tensor_name}")
            elif not with_optional and is_optional_param(op_arg):
                args_list.append(f"OptionalToValue({tensor_name})")
            else:
                args_list.append(tensor_name)
        return args_list

    def _get_view_str(self, is_view_op: bool, grad_args: list):
        """
        Get the view argument string for a DoGrad function.

        Args:
            is_view_op (bool): Whether the operator is a view operator.
            grad_args (list): A list of gradient arguments.

        Returns:
            str: ``", <first grad arg>"`` for view ops with at least one argument, else ``''``.
        """
        if is_view_op and grad_args:
            return ", " + grad_args[0]
        return ''