mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the registry listing for more details.

Files changed (577):
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  164. mindspore/lib/libmindspore_shared_lib.so +0 -0
  165. mindspore/lib/libnnacl.so +0 -0
  166. mindspore/lib/libopencv_core.so.4.5 +0 -0
  167. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  169. mindspore/lib/libps_cache.so +0 -0
  170. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  183. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  187. mindspore/lib/plugin/ascend/libakg.so +0 -0
  188. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  189. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  190. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  191. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  193. mindspore/lib/plugin/cpu/libakg.so +0 -0
  194. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  196. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  197. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  198. mindspore/nn/__init__.py +0 -2
  199. mindspore/nn/cell.py +313 -74
  200. mindspore/nn/dynamic_lr.py +21 -21
  201. mindspore/nn/layer/activation.py +22 -30
  202. mindspore/nn/layer/basic.py +15 -13
  203. mindspore/nn/layer/channel_shuffle.py +1 -1
  204. mindspore/nn/layer/container.py +271 -9
  205. mindspore/nn/layer/conv.py +323 -204
  206. mindspore/nn/layer/dense.py +8 -5
  207. mindspore/nn/layer/embedding.py +33 -27
  208. mindspore/nn/layer/flash_attention.py +61 -95
  209. mindspore/nn/layer/image.py +8 -6
  210. mindspore/nn/layer/math.py +16 -25
  211. mindspore/nn/layer/normalization.py +107 -66
  212. mindspore/nn/layer/padding.py +1 -1
  213. mindspore/nn/layer/pooling.py +131 -109
  214. mindspore/nn/layer/rnn_cells.py +27 -22
  215. mindspore/nn/layer/rnns.py +13 -16
  216. mindspore/nn/layer/thor_layer.py +1 -1
  217. mindspore/nn/layer/transformer.py +221 -154
  218. mindspore/nn/learning_rate_schedule.py +9 -1
  219. mindspore/nn/loss/loss.py +235 -174
  220. mindspore/nn/optim/ada_grad.py +2 -1
  221. mindspore/nn/optim/adadelta.py +1 -0
  222. mindspore/nn/optim/adafactor.py +2 -1
  223. mindspore/nn/optim/adam.py +7 -4
  224. mindspore/nn/optim/adamax.py +3 -2
  225. mindspore/nn/optim/adasum.py +2 -2
  226. mindspore/nn/optim/asgd.py +2 -3
  227. mindspore/nn/optim/ftrl.py +6 -5
  228. mindspore/nn/optim/lamb.py +7 -4
  229. mindspore/nn/optim/lars.py +1 -1
  230. mindspore/nn/optim/lazyadam.py +5 -3
  231. mindspore/nn/optim/momentum.py +2 -1
  232. mindspore/nn/optim/optimizer.py +53 -4
  233. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  234. mindspore/nn/optim/rmsprop.py +4 -3
  235. mindspore/nn/optim/rprop.py +23 -12
  236. mindspore/nn/optim/sgd.py +26 -11
  237. mindspore/nn/optim/thor.py +9 -7
  238. mindspore/nn/probability/bijector/bijector.py +5 -5
  239. mindspore/nn/probability/bijector/power_transform.py +27 -27
  240. mindspore/nn/probability/bijector/softplus.py +3 -3
  241. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  242. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  243. mindspore/nn/probability/distribution/beta.py +3 -3
  244. mindspore/nn/probability/distribution/categorical.py +7 -7
  245. mindspore/nn/probability/distribution/cauchy.py +0 -1
  246. mindspore/nn/probability/distribution/distribution.py +3 -3
  247. mindspore/nn/probability/distribution/gamma.py +3 -3
  248. mindspore/nn/probability/distribution/geometric.py +4 -4
  249. mindspore/nn/probability/distribution/gumbel.py +4 -4
  250. mindspore/nn/probability/distribution/log_normal.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +2 -2
  252. mindspore/nn/probability/distribution/poisson.py +4 -4
  253. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  254. mindspore/nn/probability/distribution/uniform.py +6 -6
  255. mindspore/nn/wrap/__init__.py +4 -2
  256. mindspore/nn/wrap/cell_wrapper.py +87 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  274. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  275. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  276. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  277. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  278. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  279. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  280. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  281. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  283. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  284. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  285. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  286. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  287. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  288. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  289. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  290. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  291. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  293. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  294. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  295. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  297. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  298. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  299. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  300. mindspore/ops/_primitive_cache.py +1 -1
  301. mindspore/ops/_tracefunc.py +45 -13
  302. mindspore/ops/_utils/utils.py +6 -1
  303. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  304. mindspore/ops/_vmap/vmap_base.py +3 -3
  305. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  306. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  307. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  308. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  309. mindspore/ops/arg_dtype_cast.py +54 -0
  310. mindspore/ops/composite/base.py +37 -10
  311. mindspore/ops/composite/math_ops.py +5 -4
  312. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  313. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  314. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  315. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  316. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  319. mindspore/ops/deprecated.py +304 -0
  320. mindspore/ops/function/__init__.py +4 -1
  321. mindspore/ops/function/array_func.py +174 -193
  322. mindspore/ops/function/clip_func.py +81 -13
  323. mindspore/ops/function/debug_func.py +1 -1
  324. mindspore/ops/function/grad/grad_func.py +18 -9
  325. mindspore/ops/function/image_func.py +10 -4
  326. mindspore/ops/function/linalg_func.py +5 -5
  327. mindspore/ops/function/math_func.py +575 -386
  328. mindspore/ops/function/nn_func.py +568 -260
  329. mindspore/ops/function/random_func.py +88 -57
  330. mindspore/ops/function/sparse_func.py +1 -1
  331. mindspore/ops/function/sparse_unary_func.py +14 -12
  332. mindspore/ops/function/vmap_func.py +6 -5
  333. mindspore/ops/functional.py +15 -10
  334. mindspore/ops/op_info_register.py +244 -25
  335. mindspore/ops/operations/__init__.py +31 -19
  336. mindspore/ops/operations/_grad_ops.py +71 -7
  337. mindspore/ops/operations/_inner_ops.py +350 -17
  338. mindspore/ops/operations/_quant_ops.py +4 -8
  339. mindspore/ops/operations/_sequence_ops.py +42 -0
  340. mindspore/ops/operations/array_ops.py +68 -282
  341. mindspore/ops/operations/comm_ops.py +107 -59
  342. mindspore/ops/operations/custom_ops.py +94 -70
  343. mindspore/ops/operations/debug_ops.py +8 -4
  344. mindspore/ops/operations/image_ops.py +18 -12
  345. mindspore/ops/operations/inner_ops.py +26 -3
  346. mindspore/ops/operations/math_ops.py +192 -144
  347. mindspore/ops/operations/nn_ops.py +857 -489
  348. mindspore/ops/operations/other_ops.py +0 -22
  349. mindspore/ops/operations/random_ops.py +53 -111
  350. mindspore/ops/operations/sparse_ops.py +3 -1
  351. mindspore/ops/primitive.py +24 -18
  352. mindspore/parallel/_auto_parallel_context.py +68 -8
  353. mindspore/parallel/_cost_model_context.py +2 -2
  354. mindspore/parallel/_offload_context.py +17 -3
  355. mindspore/parallel/_parallel_serialization.py +12 -5
  356. mindspore/parallel/_ps_context.py +12 -0
  357. mindspore/parallel/_tensor.py +18 -13
  358. mindspore/parallel/_transformer/layers.py +5 -3
  359. mindspore/parallel/_transformer/loss.py +1 -0
  360. mindspore/parallel/_transformer/moe.py +2 -2
  361. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  362. mindspore/parallel/_transformer/transformer.py +23 -3
  363. mindspore/parallel/_utils.py +11 -7
  364. mindspore/parallel/algo_parameter_config.py +85 -5
  365. mindspore/parallel/checkpoint_transform.py +19 -12
  366. mindspore/parallel/shard.py +21 -14
  367. mindspore/profiler/common/struct_type.py +3 -3
  368. mindspore/profiler/common/util.py +4 -2
  369. mindspore/profiler/envprofiling.py +1 -1
  370. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  371. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  372. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  373. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  374. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  375. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  376. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  377. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  378. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  379. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  380. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  381. mindspore/profiler/parser/flops_parser.py +15 -11
  382. mindspore/profiler/parser/framework_parser.py +38 -22
  383. mindspore/profiler/parser/hccl_parser.py +16 -12
  384. mindspore/profiler/parser/integrator.py +22 -11
  385. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  386. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  387. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  388. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  389. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  390. mindspore/profiler/parser/optime_parser.py +1 -1
  391. mindspore/profiler/parser/profiler_info.py +21 -2
  392. mindspore/profiler/parser/step_trace_parser.py +11 -14
  393. mindspore/profiler/profiling.py +179 -89
  394. mindspore/rewrite/api/node.py +102 -19
  395. mindspore/rewrite/api/node_type.py +5 -1
  396. mindspore/rewrite/api/pattern_engine.py +1 -1
  397. mindspore/rewrite/api/scoped_value.py +9 -17
  398. mindspore/rewrite/api/symbol_tree.py +131 -47
  399. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  400. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  401. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  402. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  403. mindspore/rewrite/common/rewrite_elog.py +5 -1
  404. mindspore/rewrite/namer.py +33 -24
  405. mindspore/rewrite/namespace.py +14 -5
  406. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  407. mindspore/rewrite/node/call_function.py +79 -0
  408. mindspore/rewrite/node/cell_container.py +135 -0
  409. mindspore/rewrite/node/control_flow.py +88 -0
  410. mindspore/rewrite/{node.py → node/node.py} +273 -234
  411. mindspore/rewrite/node/node_manager.py +254 -0
  412. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  413. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  414. mindspore/rewrite/parsers/assign_parser.py +216 -221
  415. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  416. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  417. mindspore/rewrite/parsers/constant_parser.py +9 -6
  418. mindspore/rewrite/parsers/container_parser.py +9 -7
  419. mindspore/rewrite/parsers/for_parser.py +42 -21
  420. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  421. mindspore/rewrite/parsers/if_parser.py +28 -24
  422. mindspore/rewrite/parsers/module_parser.py +196 -25
  423. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  424. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  425. mindspore/rewrite/parsers/return_parser.py +6 -6
  426. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  427. mindspore/rewrite/sparsify/utils.py +1 -1
  428. mindspore/rewrite/symbol_tree.py +523 -578
  429. mindspore/rewrite/symbol_tree_builder.py +9 -193
  430. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  431. mindspore/run_check/_check_version.py +6 -4
  432. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  433. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  434. mindspore/scipy/linalg.py +1 -1
  435. mindspore/scipy/ops.py +55 -5
  436. mindspore/scipy/optimize/__init__.py +3 -2
  437. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  438. mindspore/scipy/optimize/minimize.py +7 -3
  439. mindspore/train/_utils.py +7 -3
  440. mindspore/train/amp.py +323 -123
  441. mindspore/train/anf_ir_pb2.py +14 -2
  442. mindspore/train/callback/_backup_and_restore.py +2 -12
  443. mindspore/train/callback/_callback.py +29 -4
  444. mindspore/train/callback/_checkpoint.py +23 -8
  445. mindspore/train/callback/_early_stop.py +2 -2
  446. mindspore/train/callback/_landscape.py +4 -4
  447. mindspore/train/callback/_loss_monitor.py +2 -2
  448. mindspore/train/callback/_on_request_exit.py +2 -2
  449. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  450. mindspore/train/callback/_summary_collector.py +15 -8
  451. mindspore/train/callback/_time_monitor.py +58 -5
  452. mindspore/train/data_sink.py +5 -11
  453. mindspore/train/dataset_helper.py +84 -57
  454. mindspore/train/loss_scale_manager.py +2 -2
  455. mindspore/train/metrics/__init__.py +3 -3
  456. mindspore/train/metrics/cosine_similarity.py +1 -1
  457. mindspore/train/metrics/hausdorff_distance.py +3 -2
  458. mindspore/train/metrics/mean_surface_distance.py +3 -2
  459. mindspore/train/metrics/metric.py +39 -19
  460. mindspore/train/metrics/roc.py +2 -2
  461. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  462. mindspore/train/mind_ir_pb2.py +85 -36
  463. mindspore/train/model.py +187 -47
  464. mindspore/train/serialization.py +487 -161
  465. mindspore/train/summary/_summary_adapter.py +1 -1
  466. mindspore/train/summary/_writer_pool.py +3 -2
  467. mindspore/train/summary/summary_record.py +37 -17
  468. mindspore/train/train_thor/convert_utils.py +3 -3
  469. mindspore/train/train_thor/dataset_helper.py +1 -1
  470. mindspore/version.py +1 -1
  471. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  474. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  475. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  477. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  478. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  479. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  480. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  481. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  482. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  483. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  485. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  486. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  487. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  488. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  489. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  490. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  491. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  492. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  493. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  494. mindspore/_extends/graph_kernel/expander.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  496. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  497. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  498. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  499. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  500. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  501. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  502. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  503. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  504. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  505. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  507. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  508. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  509. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  510. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  511. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  513. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  514. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  515. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  516. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  517. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  518. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  519. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  520. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  522. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  523. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  524. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  526. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  527. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  528. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  531. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  533. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  534. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  535. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  536. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  537. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  538. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  540. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  541. mindspore/dataset/datapreprocess/__init__.py +0 -20
  542. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  543. mindspore/include/api/net.h +0 -142
  544. mindspore/nn/lr_scheduler.py +0 -262
  545. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  546. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  547. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  548. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  549. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  550. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  559. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  563. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  574. mindspore/rewrite/node_visitor.py +0 -44
  575. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  576. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,7 @@ class _OffloadConfig:
33
33
  OFFLOAD_PARAM = "offload_param"
34
34
  OFFLOAD_PATH = "offload_path"
35
35
  OFFLOAD_CPU_SIZE = "offload_cpu_size"
36
+ OFFLOAD_CHECKPOINT = "offload_checkpoint"
36
37
  OFFLOAD_DISK_SIZE = "offload_disk_size"
37
38
  ENABLE_AIO = "enable_aio"
38
39
  AIO_BLOCK_SIZE = "aio_block_size"
@@ -84,6 +85,16 @@ class _OffloadContext:
84
85
  Validator.check_string(offload_param.lower(), ["cpu", "disk"], "offload_param", "set_offload_param")
85
86
  self._context_handle.set_offload_param(offload_param.lower())
86
87
 
88
+ def set_offload_checkpoint(self, offload_checkpoint):
89
+ """Set offload_checkpoint"""
90
+ if not isinstance(offload_checkpoint, str):
91
+ raise TypeError("For 'set_offload_checkpoint', "
92
+ "the argument 'offload_checkpoint' must be str, but got the type : {}."
93
+ .format(type(offload_checkpoint)))
94
+ Validator.check_string(offload_checkpoint.lower(), ["cpu", "disk"], "offload_checkpoint",
95
+ "set_offload_checkpoint")
96
+ self._context_handle.set_offload_checkpoint(offload_checkpoint.lower())
97
+
87
98
  def set_offload_path(self, offload_path):
88
99
  """Set offload_path"""
89
100
  if not isinstance(offload_path, str):
@@ -194,7 +205,8 @@ class _OffloadContext:
194
205
  _OffloadConfig.HBM_RATIO, _OffloadConfig.OFFLOAD_CPU_SIZE,
195
206
  _OffloadConfig.OFFLOAD_DISK_SIZE, _OffloadConfig.ENABLE_AIO,
196
207
  _OffloadConfig.AIO_BLOCK_SIZE, _OffloadConfig.AIO_QUEUE_DEPTH,
197
- _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD]:
208
+ _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD,
209
+ _OffloadConfig.OFFLOAD_CHECKPOINT]:
198
210
  unknown_config.append(config_name)
199
211
 
200
212
  if unknown_config:
@@ -220,7 +232,8 @@ class _OffloadContext:
220
232
  _OffloadConfig.AUTO_OFFLOAD: self._context_handle.auto_offload(),
221
233
  _OffloadConfig.HOST_MEM_BLOCk_SIZE: self._context_handle.host_mem_block_size(),
222
234
  _OffloadConfig.CPU_RATIO: self._context_handle.cpu_ratio(),
223
- _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio()
235
+ _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio(),
236
+ _OffloadConfig.OFFLOAD_CHECKPOINT: self._context_handle.offload_checkpoint()
224
237
  }
225
238
  return offload_config
226
239
 
@@ -257,5 +270,6 @@ _set_offload_context_func_map = {
257
270
  _OffloadConfig.AUTO_OFFLOAD: offload_context().set_auto_offload,
258
271
  _OffloadConfig.HOST_MEM_BLOCk_SIZE: offload_context().set_host_mem_block_size,
259
272
  _OffloadConfig.CPU_RATIO: offload_context().set_cpu_ratio,
260
- _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio
273
+ _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio,
274
+ _OffloadConfig.OFFLOAD_CHECKPOINT: offload_context().set_offload_checkpoint
261
275
  }
@@ -330,12 +330,13 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
330
330
  device_list = list(range(0, np.prod(from_tensor_layout[0])))
331
331
  param_rank_list = _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, rank_id)
332
332
  param_rank_list_new = [rank % from_device_num for rank in param_rank_list]
333
- param_rank_list_new = set(param_rank_list_new)
334
- result_list.update(param_rank_list_new)
333
+ param_rank_set_new = set(param_rank_list_new)
334
+ result_list.update(param_rank_set_new)
335
335
  return list(result_list)
336
336
 
337
337
 
338
- def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list, dst_strategy_list):
338
+ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
339
+ dst_strategy_list, param_type_dict):
339
340
  """
340
341
  Transform model parallel dimension for distributed checkpoint files.
341
342
  """
@@ -397,15 +398,21 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
397
398
  transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])
398
399
  requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
399
400
  layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
400
- transform_param_dict[param_name] = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
401
+ transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
402
+ if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
403
+ transform_para.set_dtype(ms.bfloat16)
404
+ transform_param_dict[param_name] = transform_para
401
405
 
402
406
  # Handle those parameter like learning_rate, global_step which not in strategy_file.
403
407
  for param_name, _ in param_total_dict.items():
404
408
  if param_name not in transform_param_dict:
405
- transform_param_dict[param_name] = ms.Parameter(
409
+ transform_para = ms.Parameter(
406
410
  ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
407
411
  param_attr_dict[param_name][rank_id % device_num][0],
408
412
  param_attr_dict[param_name][rank_id % device_num][1])
413
+ if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
414
+ transform_para.set_dtype(ms.bfloat16)
415
+ transform_param_dict[param_name] = transform_para
409
416
 
410
417
  transform_param_list = [{"name": param_name, "data": param_data}
411
418
  for param_name, param_data in transform_param_dict.items()]
@@ -228,3 +228,15 @@ def _enable_distributed_mindrt():
228
228
  This method is used to distinguish from old distributed training mode.
229
229
  '''
230
230
  return ps_context().enable_distributed_mindrt()
231
+
232
+
233
+ def _set_checkpoint_load_status(status):
234
+ return ps_context().set_checkpoint_load_status(status)
235
+
236
+
237
+ def _store_warm_up_ptr_by_tensor(param_key, tensor):
238
+ return ps_context().store_warm_up_ptr_by_tensor(param_key, tensor)
239
+
240
+
241
+ def _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor):
242
+ return ps_context().store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
@@ -17,7 +17,7 @@ from __future__ import division
17
17
  from __future__ import absolute_import
18
18
 
19
19
  import numpy as np
20
-
20
+ from mindspore.common import dtype as mstype
21
21
  from mindspore.common.tensor import Tensor
22
22
  from mindspore.communication.management import get_rank, get_group_size
23
23
  from mindspore._c_expression import TensorTransform
@@ -41,7 +41,7 @@ def _get_tensor_strategy(dev_mat, tensor_map):
41
41
  if dim == -1:
42
42
  tensor_strategy.append(1)
43
43
  else:
44
- tensor_strategy.append(dev_mat[-dim-1])
44
+ tensor_strategy.append(dev_mat[-dim - 1])
45
45
  return tensor_strategy
46
46
 
47
47
 
@@ -198,7 +198,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
198
198
  return tensor_slice_index
199
199
 
200
200
 
201
- def _load_tensor(tensor, dev_mat, tensor_map):
201
+ def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
202
202
  """
203
203
  Get the tensor slice of the local device by the device matrix and the tensor map
204
204
 
@@ -216,16 +216,21 @@ def _load_tensor(tensor, dev_mat, tensor_map):
216
216
  >>> tensor_map = [1, -1]
217
217
  >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
218
218
  """
219
- rank = get_rank()
219
+ if rank_id == -1:
220
+ rank = get_rank()
221
+ else:
222
+ rank = rank_id
220
223
  tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
221
224
  tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
225
+ if tensor.dtype == mstype.bfloat16:
226
+ tensor = tensor.float()
222
227
  np_tensor = tensor.asnumpy()
223
228
  np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
224
229
  np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
225
230
  return np_tensor_slice
226
231
 
227
232
 
228
- def _load_tensor_by_layout(tensor, layout):
233
+ def _load_tensor_by_layout(tensor, layout, rank_id):
229
234
  """
230
235
  Load tensor by layout.
231
236
 
@@ -246,19 +251,19 @@ def _load_tensor_by_layout(tensor, layout):
246
251
  raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
247
252
  dev_mat = layout[0]
248
253
  tensor_map = layout[1]
249
- if len(tensor_map) == 0:
254
+ if not tensor_map:
250
255
  return tensor
251
256
  uniform_split = layout[4]
252
257
  group = layout[5]
253
258
  if uniform_split == 0:
254
259
  raise RuntimeError("The load tensor only support uniform split now")
255
- tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
260
+ tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, rank_id)
256
261
  if group:
257
262
  # get a totally shard tensor slice for parallel optimizer
258
263
  rank = get_rank(group)
259
264
  size = get_group_size(group)
260
265
  tensor_slice = np.split(tensor_slice, size)[rank]
261
- return Tensor(tensor_slice)
266
+ return Tensor(tensor_slice, tensor.dtype)
262
267
 
263
268
 
264
269
  def _reshape_param_data(param_data, dev_mat, tensor_map):
@@ -315,7 +320,6 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
315
320
  return Tensor(tensor_slices_new[0])
316
321
 
317
322
 
318
-
319
323
  def _extract_layout_item(layout_item):
320
324
  dev_matrix = layout_item[0]
321
325
  tensor_map = layout_item[1]
@@ -541,6 +545,7 @@ def _check_operator(operator):
541
545
 
542
546
  def _apply_operator(operator_name):
543
547
  """apply transform operator"""
548
+
544
549
  def _apply_reshape_operator(numpy_data, reshape_op):
545
550
  """
546
551
  Apply reshape operator.
@@ -597,8 +602,8 @@ def _apply_operator(operator_name):
597
602
  raise ValueError("The slice operator information is wrong.")
598
603
  shape_size = len(slice_op[1]) // 3
599
604
  begin = slice_op[1][:shape_size]
600
- end = slice_op[1][shape_size:shape_size*2]
601
- stride = slice_op[1][shape_size*2:]
605
+ end = slice_op[1][shape_size:shape_size * 2]
606
+ stride = slice_op[1][shape_size * 2:]
602
607
  slice_index = []
603
608
  for begin_i, end_i, strides_i in zip(begin, end, stride):
604
609
  s = slice(begin_i, end_i, strides_i)
@@ -637,8 +642,8 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
637
642
  for i in range(len(tensor_slices[0][0])):
638
643
  tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
639
644
  for j in range(1, device_count):
640
- tensor_slices_new = np.concatenate((tensor_slices_new,\
641
- np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
645
+ tensor_slices_new = np.concatenate((tensor_slices_new, \
646
+ np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
642
647
  tensor_slices_col.append(tensor_slices_new)
643
648
  new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
644
649
  for i in range(1, len(tensor_slices_col)):
@@ -424,9 +424,11 @@ class _Linear(Cell):
424
424
  self.out_channels = out_channels
425
425
  if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
426
426
  raise TypeError(f"For Linear cell, the activation should str type or nn.Cell type, but got {activation}.")
427
- if isinstance(weight_init, Tensor) and (weight_init.ndim != 2 or weight_init.shape[0] != out_channels or
428
- weight_init.shape[1] != in_channels):
429
- raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
427
+
428
+ if isinstance(weight_init, Tensor):
429
+ if weight_init.ndim != 2 or weight_init.shape[0] != out_channels \
430
+ or weight_init.shape[1] != in_channels:
431
+ raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
430
432
  weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
431
433
  self.expert_num = expert_num
432
434
  self.outer_batch = outer_batch
@@ -139,6 +139,7 @@ class _NLLLoss(Cell):
139
139
  self.add = P.Add().shard(((dp, mp), ()))
140
140
 
141
141
  def construct(self, softmax_result, one_hot_label):
142
+ """The forward of _NLLLoss"""
142
143
  log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
143
144
  loss = self.mul(log_softmax_result, one_hot_label)
144
145
  loss_unsum = self.neg(loss)
@@ -273,7 +273,7 @@ class MoE(Cell):
273
273
  if self.group_wise_a2a:
274
274
  # If capacity can't div by mp, pad for mp shard.
275
275
  if capacity % self.mp != 0:
276
- pad_size = self.mp-(capacity % self.mp)
276
+ pad_size = self.mp - (capacity % self.mp)
277
277
  if pad_size != 0:
278
278
  capacity += pad_size
279
279
  pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
@@ -330,7 +330,7 @@ class MoE(Cell):
330
330
  # Pad capacity for comp_comm_parallel_degree split.
331
331
  pad_size = 0
332
332
  if capacity % self.comp_comm_parallel_degree != 0:
333
- pad_size = self.comp_comm_parallel_degree-(capacity % self.comp_comm_parallel_degree)
333
+ pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
334
334
  capacity += pad_size
335
335
  pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
336
336
  (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
@@ -147,10 +147,11 @@ class _PipeLineConfig(_Config):
147
147
  >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
148
148
  """
149
149
 
150
- def __init__(self, pipeline_stage=1, micro_batch_num=1):
150
+ def __init__(self, pipeline_stage=1, micro_batch_num=1, pipeline_segment=1):
151
151
  Validator.check_positive_int(pipeline_stage, "pipeline_stage")
152
152
  Validator.check_positive_int(micro_batch_num, "micro_batch_num")
153
153
  self.pipeline_stage = pipeline_stage
154
+ self.pipeline_segment = pipeline_segment
154
155
  self.micro_batch_num = micro_batch_num
155
156
 
156
157
  @property
@@ -163,6 +164,16 @@ class _PipeLineConfig(_Config):
163
164
  self._pipeline_stage = value
164
165
  context.set_auto_parallel_context(pipeline_stages=value)
165
166
 
167
+ @property
168
+ def pipeline_segment(self):
169
+ return self._pipeline_segment
170
+
171
+ @pipeline_segment.setter
172
+ def pipeline_segment(self, value):
173
+ Validator.check_positive_int(value, "pipeline_segment")
174
+ self._pipeline_segment = value
175
+ context.set_auto_parallel_context(pipeline_segments=value)
176
+
166
177
  @property
167
178
  def micro_batch_num(self):
168
179
  return self._micro_batch_num
@@ -226,7 +226,8 @@ class TransformerOpParallelConfig(_Config):
226
226
  >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
227
227
  """
228
228
 
229
- def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, micro_batch_num=1,
229
+ def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, pipeline_segment=1,
230
+ micro_batch_num=1,
230
231
  recompute=default_transformer_recompute_config,
231
232
  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
232
233
  self.recompute = recompute
@@ -234,7 +235,8 @@ class TransformerOpParallelConfig(_Config):
234
235
  self.gradient_aggregation_group = gradient_aggregation_group
235
236
  self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
236
237
  vocab_emb_dp=vocab_emb_dp)
237
- self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
238
+ self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num,
239
+ pipeline_segment=pipeline_segment)
238
240
  self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
239
241
  expert_parallel=expert_parallel)
240
242
 
@@ -309,6 +311,14 @@ class TransformerOpParallelConfig(_Config):
309
311
  def pipeline_stage(self, value):
310
312
  self._pp_config.pipeline_stage = value
311
313
 
314
+ @property
315
+ def pipeline_segment(self):
316
+ return self._pp_config.pipeline_segment
317
+
318
+ @pipeline_segment.setter
319
+ def pipeline_segment(self, value):
320
+ self._pp_config.pipeline_segment = value
321
+
312
322
  @property
313
323
  def optimizer_shard(self):
314
324
  return self._optimizer_shard
@@ -429,6 +439,7 @@ class FeedForward(Cell):
429
439
  >>> print(output.shape)
430
440
  (2, 20, 15)
431
441
  """
442
+
432
443
  @_LogActionOnce(logger=logger, key='FeedForward',
433
444
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
434
445
  @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -622,6 +633,7 @@ class AttentionMask(Cell):
622
633
  [1. 1. 1. 0]
623
634
  [0. 0. 0. 0]]]
624
635
  """
636
+
625
637
  @_LogActionOnce(logger=logger, key='AttentionMask',
626
638
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
627
639
  @_args_type_validator_check(seq_length=Validator.check_positive_int,
@@ -710,6 +722,7 @@ class VocabEmbedding(Cell):
710
722
  >>> print(table.shape)
711
723
  (30, 30)
712
724
  """
725
+
713
726
  @_LogActionOnce(logger=logger, key='VocabEmbedding',
714
727
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
715
728
  @_args_type_validator_check(vocab_size=Validator.check_positive_int,
@@ -866,6 +879,7 @@ class MultiHeadAttention(Cell):
866
879
  >>> print(past[1].shape)
867
880
  (2, 3, 20, 5)
868
881
  """
882
+
869
883
  @_LogActionOnce(logger=logger, key='MultiHeadAttention',
870
884
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
871
885
  @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -1203,7 +1217,8 @@ class MultiHeadAttention(Cell):
1203
1217
  def _get_batch_size_from_query(self, query):
1204
1218
  r"""Get the batch size from query tensor"""
1205
1219
  # For the incremental prediction, the seq length for the input is 1.
1206
- if len(F.shape(query)) == 2 and ((self.use_past and self.is_first_iteration) or (not self.use_past)):
1220
+ incr_infer = self.use_past and self.is_first_iteration
1221
+ if len(F.shape(query)) == 2 and ((incr_infer) or (not self.use_past)):
1207
1222
  return F.shape(query)[0] // self.src_seq_length
1208
1223
  return F.shape(query)[0]
1209
1224
 
@@ -1459,6 +1474,7 @@ class TransformerEncoderLayer(Cell):
1459
1474
  >>> print(past[1].shape)
1460
1475
  (2, 2, 16, 4)
1461
1476
  """
1477
+
1462
1478
  @_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
1463
1479
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
1464
1480
  @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -1848,6 +1864,7 @@ class TransformerDecoderLayer(Cell):
1848
1864
  >>> print(past[3].shape)
1849
1865
  (2, 2, 20, 32)
1850
1866
  """
1867
+
1851
1868
  @_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
1852
1869
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
1853
1870
  @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -2379,6 +2396,7 @@ class TransformerEncoder(Cell):
2379
2396
  >>> print(past[0][1].shape)
2380
2397
  (2, 2, 16, 4)
2381
2398
  """
2399
+
2382
2400
  @_LogActionOnce(logger=logger, key='TransformerEncoder',
2383
2401
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
2384
2402
  @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -2613,6 +2631,7 @@ class TransformerDecoder(Cell):
2613
2631
  >>> print(past[0][3].shape)
2614
2632
  (2, 2, 20, 32)
2615
2633
  """
2634
+
2616
2635
  @_LogActionOnce(logger=logger, key='TransformerDecoder',
2617
2636
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
2618
2637
  @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -2882,6 +2901,7 @@ class Transformer(Cell):
2882
2901
  >>> print(de_past[0][3].shape)
2883
2902
  (2, 2, 20, 32)
2884
2903
  """
2904
+
2885
2905
  @_LogActionOnce(logger=logger, key='Transformer',
2886
2906
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
2887
2907
  @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -100,13 +100,14 @@ def _slice_parameter(parameter, phase, layout):
100
100
  parameter.sliced = True
101
101
  return
102
102
  if not parameter.sliced:
103
- new_tensor = _load_tensor_by_layout(parameter, layout)
103
+ rank = get_rank()
104
+ new_tensor = _load_tensor_by_layout(parameter, layout, rank)
104
105
  parameter.set_data(new_tensor, True)
105
106
 
106
107
 
107
- def _slice_tensor(tensor, layout):
108
+ def _slice_tensor(tensor, layout, rank_id):
108
109
  """Slice python tensor obj according to the layout."""
109
- new_tensor = _load_tensor_by_layout(tensor, layout)
110
+ new_tensor = _load_tensor_by_layout(tensor, layout, rank_id)
110
111
  return new_tensor
111
112
 
112
113
 
@@ -136,14 +137,17 @@ def _to_full_shapes(shapes, device_num):
136
137
  "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
137
138
  new_shape = ()
138
139
  for i, item in enumerate(shape):
139
- new_shape += (item * dataset_strategy[index][i],)
140
+ if item > 0:
141
+ new_shape += (item * dataset_strategy[index][i],) # static shape
142
+ else:
143
+ new_shape += (item,) # dynamic shape
140
144
  new_shapes.append(new_shape)
141
145
  return new_shapes
142
146
  for shape in shapes:
143
147
  new_shape = ()
144
148
  for i, item in enumerate(shape):
145
- if i == 0:
146
- new_shape += (item * device_num,)
149
+ if i == 0 and item > 0:
150
+ new_shape += (item * device_num,) # only for static shape
147
151
  else:
148
152
  new_shape += (item,)
149
153
  new_shapes.append(new_shape)
@@ -201,7 +205,7 @@ def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
201
205
  slice_index += (s,)
202
206
  new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
203
207
  new_tensor_numpy[slice_index] = data.asnumpy()
204
- new_tensor = Tensor(new_tensor_numpy)
208
+ new_tensor = Tensor(new_tensor_numpy, dtype=type_)
205
209
  lst.append(new_tensor)
206
210
  if scaling_sens:
207
211
  lst.append(Tensor(scaling_sens, mstype.float32))
@@ -229,7 +229,7 @@ def set_algo_parameters(**kwargs):
229
229
  """
230
230
  Set parameters in the algorithm for parallel strategy searching. See a typical use in
231
231
  `test_auto_parallel_resnet.py
232
- <https://gitee.com/mindspore/mindspore/blob/r2.1/tests/ut/python/parallel/test_auto_parallel_resnet.py>`_.
232
+ <https://gitee.com/mindspore/mindspore/blob/r2.2/tests/ut/python/parallel/test_auto_parallel_resnet.py>`_.
233
233
 
234
234
  Note:
235
235
  The attribute name is required. This interface works ONLY in AUTO_PARALLEL mode.
@@ -239,10 +239,10 @@ def set_algo_parameters(**kwargs):
239
239
  Default: ``True`` . For example with 8 devices available, if set ``True`` , strategy (4, 1) will not be
240
240
  included in ReLU's candidate strategies, because strategy (4, 1) only utilizes 4 devices.
241
241
  elementwise_op_strategy_follow (bool): Whether the elementwise operator has the consistent strategies as its
242
- subsequent operators. Default: ``False`` . For the example of ReLU followed by Add, where ReLU is
243
- elementwise operator, if this flag is set ``True`` , then the searched strategy by the algorithm
244
- guarantees that strategies of these two operators are consistent, e.g., ReLU's strategy (8, 1) and Add's
245
- strategy ((8, 1), (8, 1)).
242
+ subsequent operators. Elementwise operators refer to operators that operate on input element by element,
243
+ such as Add, ReLU, etc. Default: ``False`` . For the example of ReLU followed by Add, if this flag is set
244
+ ``True`` , then the searched strategy by the algorithm guarantees that strategies of these two operators
245
+ are consistent, e.g., ReLU's strategy (8, 1) and Add's strategy ((8, 1), (8, 1)).
246
246
  enable_algo_approxi (bool): Whether to enable the approximation in the algorithms. Default: ``False`` . Due to
247
247
  large solution space in searching parallel strategy for large DNN model, the algorithm takes fairly long
248
248
  time in this case. To mitigate it, if this flag is set ``True`` , an approximation is made to discard some
@@ -261,8 +261,87 @@ def set_algo_parameters(**kwargs):
261
261
  ValueError: If context keyword is not recognized.
262
262
 
263
263
  Examples:
264
+ .. note::
265
+ Before running the following examples, you need to configure the communication environment variables.
266
+
267
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
268
+ Please see the `rank table startup
269
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
270
+ for more details.
271
+
272
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
273
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
274
+
275
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
276
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
277
+
278
+ >>> import numpy as np
264
279
  >>> import mindspore as ms
280
+ >>> import mindspore.dataset as ds
281
+ >>> from mindspore import nn, ops, train
282
+ >>> from mindspore.communication import init
283
+ >>> from mindspore.common.initializer import initializer
284
+ >>>
285
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
286
+ >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
287
+ >>> search_mode="sharding_propagation")
288
+ >>> init()
289
+ >>> ms.set_algo_parameters(fully_use_devices=True)
265
290
  >>> ms.set_algo_parameters(elementwise_op_strategy_follow=True)
291
+ >>> ms.set_algo_parameters(enable_algo_approxi=True)
292
+ >>> ms.set_algo_parameters(algo_approxi_epsilon=0.2)
293
+ >>> ms.set_algo_parameters(tensor_slice_align_enable=True)
294
+ >>> ms.set_algo_parameters(tensor_slice_align_size=8)
295
+ >>>
296
+ >>> # Define the network structure.
297
+ >>> class Dense(nn.Cell):
298
+ ... def __init__(self, in_channels, out_channels):
299
+ ... super().__init__()
300
+ ... self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
301
+ ... self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
302
+ ... self.matmul = ops.MatMul()
303
+ ... self.add = ops.Add()
304
+ ...
305
+ ... def construct(self, x):
306
+ ... x = self.matmul(x, self.weight)
307
+ ... x = self.add(x, self.bias)
308
+ ... return x
309
+ >>>
310
+ >>> class FFN(nn.Cell):
311
+ ... def __init__(self):
312
+ ... super().__init__()
313
+ ... self.flatten = ops.Flatten()
314
+ ... self.dense1 = Dense(28*28, 64)
315
+ ... self.relu = ops.ReLU()
316
+ ... self.dense2 = Dense(64, 10)
317
+ ...
318
+ ... def construct(self, x):
319
+ ... x = self.flatten(x)
320
+ ... x = self.dense1(x)
321
+ ... x = self.relu(x)
322
+ ... x = self.dense2(x)
323
+ ... return x
324
+ >>> net = FFN()
325
+ >>> net.dense1.matmul.shard(((2, 1), (1, 2)))
326
+ >>>
327
+ >>> # Create dataset.
328
+ >>> step_per_epoch = 16
329
+ >>> def get_dataset(*inputs):
330
+ ... def generate():
331
+ ... for _ in range(step_per_epoch):
332
+ ... yield inputs
333
+ ... return generate
334
+ >>>
335
+ >>> input_data = np.random.rand(1, 28, 28).astype(np.float32)
336
+ >>> label_data = np.random.rand(1).astype(np.int32)
337
+ >>> fake_dataset = get_dataset(input_data, label_data)
338
+ >>> dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
339
+ >>> # Train network.
340
+ >>> optimizer = nn.Momentum(net.trainable_params(), 1e-3, 0.1)
341
+ >>> loss_fn = nn.CrossEntropyLoss()
342
+ >>> loss_cb = train.LossMonitor()
343
+ >>> model = ms.Model(network=net, loss_fn=loss_fn, optimizer=optimizer)
344
+ >>> model.train(epoch=2, train_dataset=dataset, callbacks=[loss_cb])
266
345
  """
267
346
  for key, value in kwargs.items():
268
347
  if key not in set_algo_parameters_config_func_map:
@@ -282,6 +361,7 @@ def get_algo_parameters(attr_key):
282
361
  attr_key (str): The key of the attribute. The keys include: "fully_use_devices",
283
362
  "elementwise_op_strategy_follow", "enable_algo_approxi", "algo_approxi_epsilon",
284
363
  "tensor_slice_align_enable","tensor_slice_align_size".
364
+ See :func:`mindspore.set_algo_parameters` for more details about the meaning of the attributes.
285
365
 
286
366
  Returns:
287
367
  Return attribute value according to the key.