mindspore 2.3.0rc1__cp37-none-any.whl → 2.3.0rc2__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (316)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/bin/cache_admin +0 -0
  12. mindspore/bin/cache_server +0 -0
  13. mindspore/boost/boost_cell_wrapper.py +1 -1
  14. mindspore/boost/group_loss_scale_manager.py +1 -1
  15. mindspore/common/__init__.py +4 -2
  16. mindspore/common/_register_for_recompute.py +48 -0
  17. mindspore/common/_stub_tensor.py +1 -0
  18. mindspore/common/api.py +56 -4
  19. mindspore/common/dtype.py +5 -3
  20. mindspore/common/dump.py +2 -2
  21. mindspore/common/hook_handle.py +51 -4
  22. mindspore/common/initializer.py +1 -1
  23. mindspore/common/jit_config.py +17 -6
  24. mindspore/common/parameter.py +7 -2
  25. mindspore/common/recompute.py +247 -0
  26. mindspore/common/sparse_tensor.py +2 -2
  27. mindspore/common/symbol.py +1 -1
  28. mindspore/common/tensor.py +74 -36
  29. mindspore/communication/__init__.py +3 -3
  30. mindspore/communication/management.py +30 -30
  31. mindspore/context.py +28 -15
  32. mindspore/dataset/__init__.py +5 -5
  33. mindspore/dataset/audio/__init__.py +2 -2
  34. mindspore/dataset/audio/transforms.py +51 -51
  35. mindspore/dataset/callback/ds_callback.py +2 -2
  36. mindspore/dataset/engine/cache_client.py +1 -1
  37. mindspore/dataset/engine/datasets.py +3 -3
  38. mindspore/dataset/engine/datasets_audio.py +14 -14
  39. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  40. mindspore/dataset/engine/datasets_text.py +38 -38
  41. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  42. mindspore/dataset/engine/datasets_vision.py +68 -68
  43. mindspore/dataset/text/__init__.py +3 -3
  44. mindspore/dataset/text/transforms.py +26 -26
  45. mindspore/dataset/transforms/__init__.py +1 -1
  46. mindspore/dataset/vision/__init__.py +3 -3
  47. mindspore/dataset/vision/transforms.py +92 -92
  48. mindspore/dataset/vision/utils.py +1 -1
  49. mindspore/experimental/optim/adadelta.py +2 -2
  50. mindspore/experimental/optim/adagrad.py +2 -2
  51. mindspore/experimental/optim/adam.py +2 -2
  52. mindspore/experimental/optim/adamax.py +2 -2
  53. mindspore/experimental/optim/adamw.py +2 -2
  54. mindspore/experimental/optim/asgd.py +2 -2
  55. mindspore/experimental/optim/lr_scheduler.py +24 -20
  56. mindspore/experimental/optim/nadam.py +2 -2
  57. mindspore/experimental/optim/optimizer.py +1 -1
  58. mindspore/experimental/optim/radam.py +2 -2
  59. mindspore/experimental/optim/rmsprop.py +2 -2
  60. mindspore/experimental/optim/rprop.py +2 -2
  61. mindspore/experimental/optim/sgd.py +2 -2
  62. mindspore/hal/stream.py +2 -0
  63. mindspore/include/mindapi/base/types.h +5 -0
  64. mindspore/lib/libdnnl.so.2 +0 -0
  65. mindspore/lib/libmindspore.so +0 -0
  66. mindspore/lib/libmindspore_backend.so +0 -0
  67. mindspore/lib/libmindspore_common.so +0 -0
  68. mindspore/lib/libmindspore_core.so +0 -0
  69. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  70. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  71. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  72. mindspore/lib/libmindspore_shared_lib.so +0 -0
  73. mindspore/lib/libopencv_core.so.4.5 +0 -0
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  75. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  76. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  77. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  78. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  79. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  80. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  81. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  82. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  83. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
  84. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  85. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  86. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
  87. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
  88. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
  89. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
  90. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
  91. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
  92. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
  93. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
  94. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
  164. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
  165. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
  166. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
  167. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
  168. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  169. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  170. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  171. mindspore/log.py +2 -2
  172. mindspore/mint/__init__.py +457 -0
  173. mindspore/mint/nn/__init__.py +430 -0
  174. mindspore/mint/nn/functional.py +424 -0
  175. mindspore/mint/optim/__init__.py +24 -0
  176. mindspore/mint/optim/adamw.py +186 -0
  177. mindspore/multiprocessing/__init__.py +4 -0
  178. mindspore/nn/__init__.py +3 -0
  179. mindspore/nn/cell.py +51 -47
  180. mindspore/nn/extend/__init__.py +29 -0
  181. mindspore/nn/extend/basic.py +140 -0
  182. mindspore/nn/extend/embedding.py +143 -0
  183. mindspore/nn/extend/layer/__init__.py +27 -0
  184. mindspore/nn/extend/layer/normalization.py +107 -0
  185. mindspore/nn/extend/pooling.py +117 -0
  186. mindspore/nn/generator.py +297 -0
  187. mindspore/nn/layer/basic.py +109 -1
  188. mindspore/nn/layer/container.py +2 -2
  189. mindspore/nn/layer/conv.py +6 -6
  190. mindspore/nn/layer/embedding.py +1 -1
  191. mindspore/nn/layer/normalization.py +21 -43
  192. mindspore/nn/layer/padding.py +4 -0
  193. mindspore/nn/optim/ada_grad.py +2 -2
  194. mindspore/nn/optim/adadelta.py +1 -1
  195. mindspore/nn/optim/adafactor.py +1 -1
  196. mindspore/nn/optim/adam.py +7 -7
  197. mindspore/nn/optim/adamax.py +2 -2
  198. mindspore/nn/optim/adasum.py +2 -2
  199. mindspore/nn/optim/asgd.py +2 -2
  200. mindspore/nn/optim/ftrl.py +1 -1
  201. mindspore/nn/optim/lamb.py +3 -3
  202. mindspore/nn/optim/lars.py +1 -1
  203. mindspore/nn/optim/lazyadam.py +2 -2
  204. mindspore/nn/optim/momentum.py +2 -2
  205. mindspore/nn/optim/optimizer.py +2 -2
  206. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  207. mindspore/nn/optim/rmsprop.py +2 -2
  208. mindspore/nn/optim/rprop.py +2 -2
  209. mindspore/nn/optim/sgd.py +2 -2
  210. mindspore/nn/optim/thor.py +2 -2
  211. mindspore/nn/wrap/cell_wrapper.py +9 -9
  212. mindspore/nn/wrap/grad_reducer.py +5 -5
  213. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  214. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  215. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  216. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  217. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  218. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  219. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  220. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  221. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  222. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  223. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  224. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  225. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  226. mindspore/ops/extend/__init__.py +9 -1
  227. mindspore/ops/extend/array_func.py +134 -27
  228. mindspore/ops/extend/math_func.py +3 -3
  229. mindspore/ops/extend/nn_func.py +363 -2
  230. mindspore/ops/function/__init__.py +19 -2
  231. mindspore/ops/function/array_func.py +463 -439
  232. mindspore/ops/function/clip_func.py +7 -18
  233. mindspore/ops/function/grad/grad_func.py +5 -5
  234. mindspore/ops/function/linalg_func.py +4 -4
  235. mindspore/ops/function/math_func.py +260 -243
  236. mindspore/ops/function/nn_func.py +825 -62
  237. mindspore/ops/function/random_func.py +73 -4
  238. mindspore/ops/function/sparse_unary_func.py +1 -1
  239. mindspore/ops/function/vmap_func.py +1 -1
  240. mindspore/ops/functional.py +2 -2
  241. mindspore/ops/op_info_register.py +1 -31
  242. mindspore/ops/operations/__init__.py +2 -3
  243. mindspore/ops/operations/_grad_ops.py +2 -107
  244. mindspore/ops/operations/_inner_ops.py +5 -5
  245. mindspore/ops/operations/_sequence_ops.py +2 -2
  246. mindspore/ops/operations/array_ops.py +11 -233
  247. mindspore/ops/operations/comm_ops.py +32 -32
  248. mindspore/ops/operations/custom_ops.py +7 -89
  249. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  250. mindspore/ops/operations/math_ops.py +13 -163
  251. mindspore/ops/operations/nn_ops.py +9 -316
  252. mindspore/ops/operations/random_ops.py +1 -1
  253. mindspore/ops/operations/sparse_ops.py +3 -3
  254. mindspore/ops/primitive.py +2 -2
  255. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  256. mindspore/ops_generate/arg_handler.py +24 -0
  257. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  258. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  259. mindspore/ops_generate/pyboost_utils.py +2 -17
  260. mindspore/parallel/__init__.py +3 -2
  261. mindspore/parallel/_auto_parallel_context.py +106 -1
  262. mindspore/parallel/_parallel_serialization.py +34 -2
  263. mindspore/parallel/_utils.py +16 -0
  264. mindspore/parallel/algo_parameter_config.py +4 -4
  265. mindspore/parallel/checkpoint_transform.py +249 -77
  266. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  267. mindspore/parallel/parameter_broadcast.py +1 -1
  268. mindspore/parallel/shard.py +1 -1
  269. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  270. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  271. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  272. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  273. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  274. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  275. mindspore/profiler/parser/profiler_info.py +11 -1
  276. mindspore/profiler/profiling.py +13 -5
  277. mindspore/rewrite/api/node.py +12 -12
  278. mindspore/rewrite/api/symbol_tree.py +11 -11
  279. mindspore/run_check/_check_version.py +1 -1
  280. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  281. mindspore/train/amp.py +4 -4
  282. mindspore/train/anf_ir_pb2.py +8 -2
  283. mindspore/train/callback/_backup_and_restore.py +2 -2
  284. mindspore/train/callback/_callback.py +4 -4
  285. mindspore/train/callback/_checkpoint.py +2 -2
  286. mindspore/train/callback/_early_stop.py +2 -2
  287. mindspore/train/callback/_landscape.py +4 -4
  288. mindspore/train/callback/_loss_monitor.py +2 -2
  289. mindspore/train/callback/_on_request_exit.py +2 -2
  290. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  291. mindspore/train/callback/_summary_collector.py +2 -2
  292. mindspore/train/callback/_time_monitor.py +2 -2
  293. mindspore/train/dataset_helper.py +8 -3
  294. mindspore/train/loss_scale_manager.py +2 -2
  295. mindspore/train/metrics/metric.py +3 -3
  296. mindspore/train/mind_ir_pb2.py +22 -17
  297. mindspore/train/model.py +15 -15
  298. mindspore/train/serialization.py +18 -18
  299. mindspore/train/summary/summary_record.py +7 -7
  300. mindspore/train/train_thor/convert_utils.py +3 -3
  301. mindspore/version.py +1 -1
  302. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  303. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +307 -260
  304. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
  305. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
  306. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
  307. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
  308. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
  309. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
  310. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
  311. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
  312. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
  313. /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
  314. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  315. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  316. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,424 @@ mindspore/mint/nn/functional.py
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """mint nn functional."""
16
+ from __future__ import absolute_import
17
+ from mindspore.ops.extend import max_pool2d
18
+ from mindspore.ops.functional import (
19
+ conv_transpose2d,
20
+ grid_sample
21
+ )
22
+ # 1
23
+
24
+ # 2
25
+
26
+ # 3
27
+
28
+ # 4
29
+
30
+ # 5
31
+ from mindspore.ops.functional import pad_ext as pad
32
+ # 6
33
+
34
+ # 7
35
+
36
+ # 8
37
+ from mindspore.ops.functional import layer_norm
38
+ # 9
39
+ from mindspore.ops.function.nn_func import interpolate_ext as interpolate
40
+ # 10
41
+
42
+ # 11
43
+ from mindspore.ops.functional import relu
44
+ # 12
45
+
46
+ # 13
47
+
48
+ # 14
49
+ from mindspore.ops.function.nn_func import dropout_ext as dropout
50
+ # 15
51
+
52
+ # 16
53
+
54
+ # 17
55
+
56
+ # 18
57
+
58
+ # 19
59
+
60
+ # 20
61
+
62
+ # 21
63
+
64
+ # 22
65
+
66
+ # 23
67
+
68
+ # 24
69
+
70
+ # 25
71
+
72
+ # 26
73
+
74
+ # 27
75
+
76
+ # 28
77
+
78
+ # 29
79
+
80
+ # 30
81
+
82
+ # 31
83
+
84
+ # 32
85
+
86
+ # 33
87
+
88
+ # 34
89
+
90
+ # 35
91
+
92
+ # 36
93
+ from mindspore.ops.functional import gelu
94
+ # 37
95
+
96
+ # 38
97
+
98
+ # 39
99
+ from mindspore.ops.functional import group_norm
100
+ # 40
101
+
102
+ # 41
103
+
104
+ # 42
105
+
106
+ # 43
107
+
108
+ # 44
109
+
110
+ # 45
111
+
112
+ # 46
113
+ from mindspore.ops.functional import silu
114
+ # 47
115
+
116
+ # 48
117
+
118
+ # 49
119
+ from mindspore.ops.functional import sigmoid
120
+ # 50
121
+
122
+ # 51
123
+
124
+ # 52
125
+ from mindspore.ops.functional import embedding
126
+ # 53
127
+
128
+ # 54
129
+
130
+ # 55
131
+
132
+ # 56
133
+
134
+ # 57
135
+
136
+ # 58
137
+
138
+ # 59
139
+
140
+ # 60
141
+
142
+ # 61
143
+
144
+ # 62
145
+
146
+ # 63
147
+
148
+ # 64
149
+
150
+ # 65
151
+
152
+ # 66
153
+
154
+ # 67
155
+
156
+ # 68
157
+
158
+ # 69
159
+
160
+ # 70
161
+
162
+ # 71
163
+
164
+ # 72
165
+
166
+ # 73
167
+
168
+ # 74
169
+
170
+ # 75
171
+
172
+ # 76
173
+
174
+ # 77
175
+
176
+ # 78
177
+
178
+ # 79
179
+
180
+ # 80
181
+
182
+ # 81
183
+
184
+ # 82
185
+
186
+ # 83
187
+
188
+ # 84
189
+
190
+ # 85
191
+
192
+ # 86
193
+
194
+ # 87
195
+
196
+ # 88
197
+
198
+ # 89
199
+
200
+ # 90
201
+ from mindspore.ops.function.nn_func import avg_pool2d_ext as avg_pool2d
202
+ # 91
203
+
204
+ # 92
205
+ from mindspore.ops.extend import leaky_relu_ext as leaky_relu
206
+ # 93
207
+ from mindspore.ops.function.nn_func import softplus_ext as softplus
208
+ # 94
209
+ from mindspore.ops.function.math_func import tanh
210
+ # 95
211
+
212
+ # 96
213
+
214
+ # 97
215
+
216
+ # 98
217
+
218
+ # 99
219
+
220
+ # 100
221
+
222
+ __all__ = [
223
+ 'conv_transpose2d',
224
+ 'max_pool2d',
225
+ # 1
226
+
227
+ # 2
228
+
229
+ # 3
230
+
231
+ # 4
232
+
233
+ # 5
234
+ 'pad',
235
+ # 6
236
+
237
+ # 7
238
+
239
+ # 8
240
+ 'layer_norm',
241
+ # 9
242
+ 'interpolate',
243
+ # 10
244
+
245
+ # 11
246
+ 'relu',
247
+ # 12
248
+
249
+ # 13
250
+
251
+ # 14
252
+ 'dropout',
253
+ # 15
254
+
255
+ # 16
256
+
257
+ # 17
258
+
259
+ # 18
260
+
261
+ # 19
262
+
263
+ # 20
264
+
265
+ # 21
266
+
267
+ # 22
268
+
269
+ # 23
270
+
271
+ # 24
272
+
273
+ # 25
274
+
275
+ # 26
276
+
277
+ # 27
278
+
279
+ # 28
280
+
281
+ # 29
282
+
283
+ # 30
284
+
285
+ # 31
286
+
287
+ # 32
288
+
289
+ # 33
290
+
291
+ # 34
292
+
293
+ # 35
294
+
295
+ # 36
296
+ 'gelu',
297
+ # 37
298
+
299
+ # 38
300
+
301
+ # 39
302
+ 'group_norm',
303
+ # 40
304
+
305
+ # 41
306
+
307
+ # 42
308
+
309
+ # 43
310
+
311
+ # 44
312
+
313
+ # 45
314
+
315
+ # 46
316
+ 'silu',
317
+ # 47
318
+
319
+ # 48
320
+
321
+ # 49
322
+ 'sigmoid',
323
+ # 50
324
+
325
+ # 51
326
+
327
+ # 52
328
+ 'embedding',
329
+ # 53
330
+
331
+ # 54
332
+
333
+ # 55
334
+
335
+ # 56
336
+
337
+ # 57
338
+
339
+ # 58
340
+
341
+ # 59
342
+
343
+ # 60
344
+
345
+ # 61
346
+
347
+ # 62
348
+
349
+ # 63
350
+
351
+ # 64
352
+
353
+ # 65
354
+
355
+ # 66
356
+
357
+ # 67
358
+
359
+ # 68
360
+
361
+ # 69
362
+
363
+ # 70
364
+
365
+ # 71
366
+
367
+ # 72
368
+
369
+ # 73
370
+
371
+ # 74
372
+
373
+ # 75
374
+
375
+ # 76
376
+
377
+ # 77
378
+
379
+ # 78
380
+
381
+ # 79
382
+
383
+ # 80
384
+
385
+ # 81
386
+
387
+ # 82
388
+
389
+ # 83
390
+
391
+ # 84
392
+
393
+ # 85
394
+
395
+ # 86
396
+
397
+ # 87
398
+
399
+ # 88
400
+
401
+ # 89
402
+
403
+ # 90
404
+ 'avg_pool2d',
405
+ # 91
406
+ 'grid_sample',
407
+ # 92
408
+ 'leaky_relu',
409
+ # 93
410
+ 'softplus',
411
+ # 94
412
+ 'tanh',
413
+ # 95
414
+
415
+ # 96
416
+
417
+ # 97
418
+
419
+ # 98
420
+
421
+ # 99
422
+
423
+ # 100
424
+ ]
@@ -0,0 +1,24 @@ mindspore/mint/optim/__init__.py
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """
16
+ Optimizer.
17
+
18
+ Provide common optimizers for training, such as AdamW.
19
+ The optimizer is used to calculate and update the gradients.
20
+ """
21
+ from __future__ import absolute_import
22
+ from mindspore.mint.optim.adamw import AdamW
23
+
24
+ __all__ = ['AdamW']
@@ -0,0 +1,186 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """adamw"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.parameter import Parameter
20
+ from mindspore.common.tensor import Tensor
21
+ from mindspore.common import dtype as mstype
22
+ from mindspore.ops import auto_generate as gen
23
+ from mindspore.experimental.optim.optimizer import Optimizer
24
+ from mindspore import _checkparam as validator
25
+
26
+ _optim_adamw_opt = C.MultitypeFuncGraph("optim_adamw_opt")
27
+ hyper_map = C.HyperMap()
28
+
29
+
30
+ @_optim_adamw_opt.register("Function", "Float", "Float", "Float", "Float", "Float", "Tensor", "Bool", "Bool", "Tensor",
31
+ "Tensor", "Tensor", "Tensor", "Tensor")
32
+ def _run_optim_adamw_opt(opt, beta1, beta2, lr, eps, weight_decay, step, amsgrad, maximize, parameters, grads, exp_avg,
33
+ exp_avg_sq, max_exp_avg_sq):
34
+ """Apply adamw optimizer to the weight parameter."""
35
+ success = True
36
+ opt(parameters, exp_avg, exp_avg_sq, max_exp_avg_sq, P.Cast()(grads, F.dtype(parameters)), step, lr, beta1, beta2,
37
+ weight_decay, eps, amsgrad, maximize)
38
+ return success
39
+
40
+
41
+ def _check_param_value(betas, eps, weight_decay, lr, amsgrad, maximize, prim_name):
42
+ """Check the type of inputs."""
43
+ validator.check_value_type('betas', betas, [tuple], prim_name)
44
+ validator.check("betas size", len(betas), "", [2], validator.IN, prim_name)
45
+ validator.check_value_type("betas[0]", betas[0], [float], prim_name)
46
+ validator.check_value_type("betas[1]", betas[1], [float], prim_name)
47
+ validator.check_value_type("eps", eps, [float], prim_name)
48
+ validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
49
+ validator.check_value_type("lr", lr, [float], prim_name)
50
+ validator.check_value_type("amsgrad", amsgrad, [bool], prim_name)
51
+ validator.check_value_type("maximize", maximize, [bool], prim_name)
52
+
53
+
54
+ class AdamW(Optimizer):
55
+ r"""
56
+ Implements Adam Weight Decay algorithm.
57
+
58
+ .. math::
59
+ \begin{aligned}
60
+ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
61
+ \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
62
+ \: \epsilon \text{ (epsilon)} \\
63
+ &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
64
+ \: \textit{maximize} \\
65
+ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
66
+ \text{ (second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
67
+ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
68
+ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
69
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
70
+ &\hspace{5mm}\textbf{else} \\
71
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
72
+ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
73
+ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
74
+ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
75
+ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
76
+ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
77
+ &\hspace{5mm}\textbf{if} \: amsgrad \\
78
+ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
79
+ \widehat{v_t}) \\
80
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
81
+ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
82
+ &\hspace{5mm}\textbf{else} \\
83
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
84
+ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
85
+ &\bf{return} \: \theta_t \\[-1.ex]
86
+ \end{aligned}
87
+
88
+ .. warning::
89
+ This is an experimental optimizer API that is subject to change.
90
+ This module must be used with lr scheduler module in `LRScheduler Class
91
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
92
+
93
+ Args:
94
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
95
+ parameter groups
96
+ lr (float, optional): learning rate. Must be a float, as enforced by the constructor's type check. Default: ``1e-3``.
97
+ betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
98
+ Default: ``(0.9, 0.999)``.
99
+ eps (float, optional): term added to the denominator to improve
100
+ numerical stability. Default: ``1e-8``.
101
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``1e-2``.
102
+ amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
103
+
104
+ Keyword Args:
105
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
106
+ Default: ``False``.
107
+
108
+ Inputs:
109
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
110
+
111
+ Raises:
112
+ TypeError: If the learning rate is not a float.
113
+ ValueError: If the learning rate is less than 0.
114
+ ValueError: If the `eps` is less than 0.0.
115
+ ValueError: If any element of `betas` is not in the range [0, 1).
116
+ ValueError: If the `weight_decay` is less than 0.
117
+
118
+ Supported Platforms:
119
+ ``Ascend`` ``GPU`` ``CPU``
120
+
121
+ Examples:
122
+ >>> import mindspore
123
+ >>> from mindspore import nn
124
+ >>> from mindspore.mint import optim
125
+ >>> # Define the network structure of LeNet5. Refer to
126
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
127
+ >>> net = LeNet5()
128
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
129
+ >>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
130
+ >>> def forward_fn(data, label):
131
+ ... logits = net(data)
132
+ ... loss = loss_fn(logits, label)
133
+ ... return loss, logits
134
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
135
+ >>> def train_step(data, label):
136
+ ... (loss, _), grads = grad_fn(data, label)
137
+ ... optimizer(grads)
138
+ ... return loss
139
+ """
140
+
141
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
142
+ weight_decay=1e-2, amsgrad=False, *, maximize=False):
143
+ _check_param_value(betas, eps, weight_decay, lr, amsgrad, maximize, self.cls_name)
144
+ if lr < 0.0:
145
+ raise ValueError("Invalid learning rate: {}".format(lr))
146
+ if eps < 0.0:
147
+ raise ValueError("Invalid epsilon value: {}".format(eps))
148
+ if not 0.0 <= betas[0] < 1.0:
149
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
150
+ if not 0.0 <= betas[1] < 1.0:
151
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
152
+ if weight_decay < 0.0:
153
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
154
+
155
+ defaults = dict(lr=lr, betas=betas, eps=eps,
156
+ weight_decay=weight_decay, amsgrad=amsgrad,
157
+ maximize=maximize)
158
+ super(AdamW, self).__init__(params, defaults)
159
+
160
+ self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
161
+ self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
162
+ self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
163
+ self.state_step = Parameter(Tensor([0], mstype.float32), "state_step")
164
+ self.increase_tensor = Tensor(1, mstype.float32)
165
+ self.assignadd = P.AssignAdd()
166
+ self.op_cast = P.Cast()
167
+ self.adamw_opt = gen.AdamWeightDecayExt()
168
+
169
+ def construct(self, gradients):
170
+ self.assignadd(self.state_step, self.increase_tensor)
171
+ for group_id, group in enumerate(self.param_groups):
172
+ beta1, beta2 = group['betas']
173
+ maximize = group.get("maximize")
174
+ start_id = self.group_start_id[group_id]
175
+ end_id = self.group_start_id[group_id + 1]
176
+ lr = self.lrs[group_id]
177
+ if isinstance(group.get("lr"), float):
178
+ lr = self.op_cast(group.get("lr"), mstype.float32)
179
+ grads = tuple([grad if not maximize else F.neg(grad) for grad in gradients[start_id: end_id]])
180
+
181
+ self.hyper_map(F.partial(_optim_adamw_opt, self.adamw_opt, beta1, beta2, float(lr),
182
+ group.get("eps"), group.get("weight_decay"), self.state_step,
183
+ group.get("amsgrad"), maximize),
184
+ self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
185
+ self.exp_avg_sq[start_id: end_id], self.max_exp_avg_sq[start_id: end_id])
186
+ return True
@@ -16,6 +16,7 @@
16
16
  mindspore.multiprocessing is a wrapper around the native `multiprocessing` module.
17
17
  Some methods are overridden to support fork-based multiprocessing.
18
18
  """
19
+ import types
19
20
  import signal
20
21
  import multiprocessing as mp
21
22
  from multiprocessing import *
@@ -64,5 +65,8 @@ class Pool(mp.pool.Pool): # pylint: disable=function-redefined, abstract-method
64
65
  """
65
66
  def Process(self, *args, **kwds):
66
67
  if self._ctx.get_start_method() == "fork":
68
+ # In Python 3.8.0 and later, Pool.Process() is a static method whose first argument is 'ctx'; drop that argument here.
69
+ if isinstance(super().Process, types.FunctionType):
70
+ args = args[1:]
67
71
  return _MsProcess(*args, **kwds)
68
72
  return super().Process(*args, **kwds)
mindspore/nn/__init__.py CHANGED
@@ -21,6 +21,7 @@ from __future__ import absolute_import
21
21
 
22
22
  from mindspore.nn import layer, loss, optim, wrap, grad, metrics, probability, sparse, dynamic_lr, reinforcement
23
23
  from mindspore.nn.learning_rate_schedule import *
24
+ from mindspore.nn.generator import *
24
25
  from mindspore.nn.dynamic_lr import *
25
26
  from mindspore.nn.cell import Cell, GraphCell
26
27
  from mindspore.nn.layer import *
@@ -31,6 +32,7 @@ from mindspore.nn.wrap import *
31
32
  from mindspore.nn.grad import Jvp, Vjp
32
33
  from mindspore.nn.sparse import *
33
34
  from mindspore.nn.reinforcement import *
35
+ from mindspore.nn import extend
34
36
 
35
37
  __all__ = ["Cell", "GraphCell"]
36
38
  __all__.extend(layer.__all__)
@@ -43,5 +45,6 @@ __all__.extend(sparse.__all__)
43
45
  __all__.extend(learning_rate_schedule.__all__)
44
46
  __all__.extend(dynamic_lr.__all__)
45
47
  __all__.extend(reinforcement.__all__)
48
+ __all__.extend(generator.__all__)
46
49
 
47
50
  __all__.sort()