mindspore 2.1.0__cp37-none-any.whl → 2.2.10__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic. See the package registry's advisory page for details.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,6 @@ from mindspore import log as logger
19
19
  from mindspore.ops import signature as sig
20
20
  from mindspore import _checkparam as validator
21
21
  from mindspore.common import dtype as mstype
22
- from mindspore.common._decorator import deprecated
23
22
  from mindspore.ops.primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
24
23
  from mindspore.ops.operations._pyfunc_registry import add_pyfunc
25
24
  from mindspore._c_expression import typing
@@ -738,27 +737,6 @@ class Pull(PrimitiveWithInfer):
738
737
  return mstype.float32
739
738
 
740
739
 
741
- class identity(Primitive):
742
- """
743
- The :class:`mindspore.ops.identity` interface is deprecated, please use the :class:`mindspore.nn.Identity` instead.
744
-
745
- Supported Platforms:
746
- Deprecated
747
- """
748
-
749
- # Side effect will propagated from the first argument to return value.
750
- side_effect_propagate = 1
751
-
752
- @prim_attr_register
753
- def __init__(self):
754
- """Initialize identity."""
755
- self.add_prim_attr('side_effect_propagate', 1)
756
-
757
- @deprecated('2.0', 'nn.Identity', False)
758
- def __call__(self, x):
759
- return x
760
-
761
-
762
740
  class PyInterpret(Primitive):
763
741
  r"""
764
742
  Interpret Python expression.
@@ -83,15 +83,10 @@ class TruncatedNormal(Primitive):
83
83
  Note:
84
84
  - The value of `shape` must be greater than zero. The output length can not exceed 1000000.
85
85
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
86
- and the random seed is the initial value of this random number. If the random seed is the same in two
86
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
87
87
  separate calls, the random number generated will not change.
88
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
89
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
90
- with 0 to generate random number.
91
- - Global random seed is not set, operator-level random seed is set: 0
92
- splices with the operator-level random seed to generate random number.
93
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
94
- operator-level random seed to generate random number.
88
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
89
+ to worry about which seed is more important.
95
90
 
96
91
  Args:
97
92
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -152,15 +147,10 @@ class StandardNormal(Primitive):
152
147
 
153
148
  Note:
154
149
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
155
- and the random seed is the initial value of this random number. If the random seed is the same in two
150
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
156
151
  separate calls, the random number generated will not change.
157
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
158
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
159
- with 0 to generate random number.
160
- - Global random seed is not set, operator-level random seed is set: 0
161
- splices with the operator-level random seed to generate random number.
162
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
163
- operator-level random seed to generate random number.
152
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
153
+ to worry about which seed is more important.
164
154
 
165
155
  Args:
166
156
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -208,15 +198,10 @@ class StandardLaplace(Primitive):
208
198
 
209
199
  Note:
210
200
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
211
- and the random seed is the initial value of this random number. If the random seed is the same in two
201
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
212
202
  separate calls, the random number generated will not change.
213
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
214
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
215
- with 0 to generate random number.
216
- - Global random seed is not set, operator-level random seed is set: 0
217
- splices with the operator-level random seed to generate random number.
218
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
219
- operator-level random seed to generate random number.
203
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
204
+ to worry about which seed is more important.
220
205
 
221
206
  Args:
222
207
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -266,15 +251,10 @@ class RandomGamma(Primitive):
266
251
 
267
252
  Note:
268
253
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
269
- and the random seed is the initial value of this random number. If the random seed is the same in two
254
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
270
255
  separate calls, the random number generated will not change.
271
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
272
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
273
- with 0 to generate random number.
274
- - Global random seed is not set, operator-level random seed is set: 0
275
- splices with the operator-level random seed to generate random number.
276
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
277
- operator-level random seed to generate random number.
256
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
257
+ to worry about which seed is more important.
278
258
 
279
259
  Args:
280
260
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -380,15 +360,10 @@ class Gamma(PrimitiveWithInfer):
380
360
 
381
361
  Note:
382
362
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
383
- and the random seed is the initial value of this random number. If the random seed is the same in two
363
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
384
364
  separate calls, the random number generated will not change.
385
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
386
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
387
- with 0 to generate random number.
388
- - Global random seed is not set, operator-level random seed is set: 0
389
- splices with the operator-level random seed to generate random number.
390
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
391
- operator-level random seed to generate random number.
365
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
366
+ to worry about which seed is more important.
392
367
 
393
368
  Args:
394
369
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -468,15 +443,10 @@ class ParameterizedTruncatedNormal(Primitive):
468
443
  Note:
469
444
  - The value in tensor `min` must be strictly less than `max` at any position after broadcasting.
470
445
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
471
- and the random seed is the initial value of this random number. If the random seed is the same in two
446
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
472
447
  separate calls, the random number generated will not change.
473
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
474
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
475
- with 0 to generate random number.
476
- - Global random seed is not set, operator-level random seed is set: 0
477
- splices with the operator-level random seed to generate random number.
478
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
479
- operator-level random seed to generate random number.
448
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
449
+ to worry about which seed is more important.
480
450
 
481
451
  Args:
482
452
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -551,15 +521,10 @@ class Poisson(PrimitiveWithInfer):
551
521
 
552
522
  Note:
553
523
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
554
- and the random seed is the initial value of this random number. If the random seed is the same in two
524
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
555
525
  separate calls, the random number generated will not change.
556
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
557
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
558
- with 0 to generate random number.
559
- - Global random seed is not set, operator-level random seed is set: 0
560
- splices with the operator-level random seed to generate random number.
561
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
562
- operator-level random seed to generate random number.
526
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
527
+ to worry about which seed is more important.
563
528
 
564
529
  Args:
565
530
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -630,15 +595,10 @@ class RandomPoisson(Primitive):
630
595
 
631
596
  Note:
632
597
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
633
- and the random seed is the initial value of this random number. If the random seed is the same in two
598
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
634
599
  separate calls, the random number generated will not change.
635
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
636
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
637
- with 0 to generate random number.
638
- - Global random seed is not set, operator-level random seed is set: 0
639
- splices with the operator-level random seed to generate random number.
640
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
641
- operator-level random seed to generate random number.
600
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
601
+ to worry about which seed is more important.
642
602
 
643
603
  Args:
644
604
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -705,15 +665,10 @@ class UniformInt(Primitive):
705
665
  Note:
706
666
  - The number in tensor minval must be strictly less than maxval at any position after broadcasting.
707
667
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
708
- and the random seed is the initial value of this random number. If the random seed is the same in two
668
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
709
669
  separate calls, the random number generated will not change.
710
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
711
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
712
- with 0 to generate random number.
713
- - Global random seed is not set, operator-level random seed is set: 0
714
- splices with the operator-level random seed to generate random number.
715
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
716
- operator-level random seed to generate random number.
670
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
671
+ to worry about which seed is more important.
717
672
 
718
673
  Args:
719
674
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -769,15 +724,16 @@ class UniformReal(Primitive):
769
724
 
770
725
  Note:
771
726
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
772
- and the random seed is the initial value of this random number. If the random seed is the same in two
727
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
773
728
  separate calls, the random number generated will not change.
774
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
775
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
776
- with 0 to generate random number.
777
- - Global random seed is not set, operator-level random seed is set: 0
778
- splices with the operator-level random seed to generate random number.
779
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
780
- operator-level random seed to generate random number.
729
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
730
+ to worry about which seed is more important.
731
+ - Currently, on the Ascend platform, `shape` as a Tensor is not supported.
732
+ This is supported on CPU/GPU platforms. When the input is a Tensor,
733
+ the supported data types are as follows:
734
+
735
+ - GPU: int32, int64.
736
+ - CPU: int16, int32, int64.
781
737
 
782
738
  Args:
783
739
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -787,7 +743,6 @@ class UniformReal(Primitive):
787
743
 
788
744
  Inputs:
789
745
  - **shape** (Union[tuple, Tensor]) - The shape of tensor to be generated. Only constant value is allowed.
790
- Supported dtypes: int16, int32, int64.
791
746
 
792
747
  Outputs:
793
748
  Tensor. The shape that the input 'shape' denotes. The dtype is float32.
@@ -809,6 +764,7 @@ class UniformReal(Primitive):
809
764
  >>> print(result)
810
765
  (2, 2)
811
766
  """
767
+
812
768
  @prim_attr_register
813
769
  def __init__(self, seed=0, seed2=0):
814
770
  """Initialize UniformReal"""
@@ -826,15 +782,10 @@ class RandomChoiceWithMask(Primitive):
826
782
 
827
783
  Note:
828
784
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
829
- and the random seed is the initial value of this random number. If the random seed is the same in two
785
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
830
786
  separate calls, the random number generated will not change.
831
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
832
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
833
- with 0 to generate random number.
834
- - Global random seed is not set, operator-level random seed is set: 0
835
- splices with the operator-level random seed to generate random number.
836
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
837
- operator-level random seed to generate random number.
787
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
788
+ to worry about which seed is more important.
838
789
 
839
790
  Args:
840
791
  count (int, optional): Number of items expected to get and the number must be greater than 0. Default: ``256`` .
@@ -850,8 +801,8 @@ class RandomChoiceWithMask(Primitive):
850
801
  Outputs:
851
802
  Two tensors, the first one is the index tensor and the other one is the mask tensor.
852
803
 
853
- - **index** (Tensor) - The output shape is 2-D.
854
- - **mask** (Tensor) - The output shape is 1-D.
804
+ - **index** (Tensor) - The output shape is 2-D, its shape is :math:`(count, rank of input_x)`.
805
+ - **mask** (Tensor) - The output shape is 1-D, its shape is :math:`(count)`.
855
806
 
856
807
  Supported Platforms:
857
808
  ``Ascend`` ``GPU`` ``CPU``
@@ -945,15 +896,10 @@ class Multinomial(Primitive):
945
896
  - The rows of input do not need to sum to one (in which case we use the values as weights),
946
897
  but must be non-negative, finite and have a non-zero sum.
947
898
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
948
- and the random seed is the initial value of this random number. If the random seed is the same in two
899
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
949
900
  separate calls, the random number generated will not change.
950
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
951
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
952
- with 0 to generate random number.
953
- - Global random seed is not set, operator-level random seed is set: 0
954
- splices with the operator-level random seed to generate random number.
955
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
956
- operator-level random seed to generate random number.
901
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
902
+ to worry about which seed is more important.
957
903
 
958
904
  Args:
959
905
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -1024,8 +970,8 @@ class MultinomialWithReplacement(Primitive):
1024
970
  Inputs:
1025
971
  - **x** (Tensor) - the input tensor containing the cumsum of probabilities, must be 1 or 2
1026
972
  dimensions.
1027
- - **seed** (Tensor) - If `seed` is set to -1, and `offset` is set to 0, the random number
1028
- generator is seeded by a random seed. Otherwise, it is seeded by the given seed.
973
+ - **seed** (Tensor) - If `seed` and 'offset' are both set to 0, the random number generator
974
+ is seeded by a random seed. Otherwise, it is seeded by the given seed and offset.
1029
975
  Supported dtype: int64.
1030
976
  - **offset** (Tensor) - Offset used to avoid seed collision. Supported dtype: int64.
1031
977
 
@@ -1072,7 +1018,9 @@ class UniformCandidateSampler(Primitive):
1072
1018
  range_max (int): The number of possible classes, must be non-negative.
1073
1019
  seed (int, optional): Used for random number generation, must be non-negative. If seed has a value of 0,
1074
1020
  the seed will be replaced with a randomly generated value. Default: ``0`` .
1075
- remove_accidental_hits (bool, optional): Whether accidental hit is removed. Default: ``False`` .
1021
+ remove_accidental_hits (bool, optional): Whether accidental hit is removed.
1022
+ Accidental hit is when one of the true classes matches one of the sample classes.
1023
+ Set ``True`` to remove the accidental hit of sampling the true class as a sample class. Default: ``False`` .
1076
1024
 
1077
1025
  Inputs:
1078
1026
  - **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of
@@ -1128,7 +1076,6 @@ class UniformCandidateSampler(Primitive):
1128
1076
  self.add_prim_attr("side_effect_hidden", True)
1129
1077
 
1130
1078
 
1131
-
1132
1079
  class LogUniformCandidateSampler(Primitive):
1133
1080
  r"""
1134
1081
  Generates random labels with a log-uniform distribution for sampled_candidates.
@@ -1206,15 +1153,10 @@ class RandomShuffle(Primitive):
1206
1153
 
1207
1154
  Note:
1208
1155
  - Random seed: a set of regular random numbers can be obtained through some complex mathematical algorithms,
1209
- and the random seed is the initial value of this random number. If the random seed is the same in two
1156
+ and the random seed determines the initial value of this random number. If the random seed is the same in two
1210
1157
  separate calls, the random number generated will not change.
1211
- - Global random seed and operator-level random seed are not set or both set to 0: behavior is completely random.
1212
- - Global random seed is set, but operator-level random seed is not set: A global random seed will splice
1213
- with 0 to generate random number.
1214
- - Global random seed is not set, operator-level random seed is set: 0
1215
- splices with the operator-level random seed to generate random number.
1216
- - Both Global random and operator-level random seed are set: the global random seed will splice with the
1217
- operator-level random seed to generate random number.
1158
+ - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
1159
+ to worry about which seed is more important.
1218
1160
 
1219
1161
  Args:
1220
1162
  seed (int, optional): The operator-level random seed, used to generate random numbers,
@@ -1615,7 +1615,7 @@ class SparseMatrixSoftmax(Primitive):
1615
1615
  if not isinstance(dtype, (type(mstype.float32), type(mstype.single), type(mstype.float64),
1616
1616
  type(mstype.double))):
1617
1617
  raise TypeError(
1618
- "Only float32 and float64 type data are supported, but got {}".format(dtype))
1618
+ f"Only float32 and float64 type data are supported, but got {dtype}")
1619
1619
  self.add_prim_attr("dtype", dtype)
1620
1620
  self.init_prim_io_names(inputs=['x_dense_shape', 'x_batch_pointers', 'x_row_pointers',
1621
1621
  'x_col_indices', 'x_values'],
@@ -2602,6 +2602,8 @@ class RaggedTensorToTensor(Primitive):
2602
2602
  raise ValueError(
2603
2603
  f"For {self.name}, the each element of row_partition_types must be 'ROW_SPLITS' "
2604
2604
  f"when row_splits tensor.")
2605
+ self.num_row_partition_tensors = len(row_partition_types)
2606
+ self.add_prim_attr("num_row_partition_tensors", self.num_row_partition_tensors)
2605
2607
 
2606
2608
 
2607
2609
  class SparseCross(Primitive):
@@ -25,7 +25,7 @@ from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
25
25
  from mindspore.common.parameter import Parameter
26
26
  from mindspore.common.api import _pynative_executor
27
27
  from mindspore.common._stub_tensor import _convert_stub
28
- from mindspore._c_expression import Primitive_, prim_type
28
+ from mindspore._c_expression import Primitive_, prim_type, typing
29
29
  from mindspore import _checkparam as Validator
30
30
  from mindspore.ops import signature as sig
31
31
 
@@ -746,6 +746,8 @@ def _check_contains_variable(item_dtype, item_value):
746
746
  if _check_contains_variable(item_dtype[i], element):
747
747
  return True
748
748
  elif isinstance(item_value, dict):
749
+ if isinstance(item_dtype, typing.Keyword):
750
+ return item_value is None
749
751
  for i in range(len(item_value)):
750
752
  if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
751
753
  return True
@@ -756,9 +758,7 @@ def _check_contains_variable(item_dtype, item_value):
756
758
 
757
759
 
758
760
  def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
759
- """
760
- Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
761
- to compute constant value using the constants in the constructor.
761
+ """Used to calculate constant in graph compiling process and improve compile performance in GRAPH_MODE.
762
762
 
763
763
  Args:
764
764
  fn (function): A `fn` use as the infer_value of the output operator. Default: ``None`` .
@@ -772,22 +772,27 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
772
772
  and the warning message will raised if the parameter is not const value. Default: ``True`` .
773
773
 
774
774
  Examples:
775
- >>> from mindspore.ops import constexpr
776
- >>> a = (1, 2)
777
- >>> # make an operator to calculate tuple len
778
- >>> @constexpr
779
- ... def tuple_len(x):
780
- ... return len(x)
775
+
776
+ >>> import mindspore as ms
777
+ >>> # define a constant calculate function with for loop inside and use constexpr to accelerate the compile
778
+ >>> # process.
779
+ >>> @ms.constexpr
780
+ ... def for_loop_calculate(range_num):
781
+ ... out = 0
782
+ ... for i in range(range_num):
783
+ ... if i %2 == 0 and i % 7 != 0:
784
+ ... out = out + i
785
+ ... return out // range_num
781
786
  ...
782
- >>> print(tuple_len(a))
783
- 2
784
- >>> # make an operator class to calculate tuple len
785
- >>> @constexpr(get_instance=False, name="TupleLen")
786
- ... def tuple_len_class(x):
787
- ... return len(x)
787
+ >>> # construct a net and run with GRAPH_MODE.
788
+ >>> @ms.jit
789
+ ... def my_func(x):
790
+ ... new_shape = for_loop_calculate(100000)
791
+ ... return ms.ops.broadcast_to(x, (new_shape, ))
788
792
  ...
789
- >>> print(tuple_len_class()(a))
790
- 2
793
+ >>> out = my_func(ms.Tensor([1]))
794
+ >>> print(out.shape)
795
+ >>> (21428, )
791
796
  """
792
797
 
793
798
  def deco(fn):
@@ -844,6 +849,7 @@ def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
844
849
  reuse_result (bool): If ``True`` , the operator will be executed once and reuse the result next time,
845
850
  otherwise the operator will always be executed. Default: ``True`` .
846
851
  """
852
+
847
853
  def deco(fn):
848
854
  """Decorator for CompileOp."""
849
855
 
@@ -62,6 +62,7 @@ class _ParallelOptimizerConfig:
62
62
  """
63
63
  GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
64
64
  PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"
65
+ OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"
65
66
 
66
67
 
67
68
  class _AutoParallelContext:
@@ -176,7 +177,6 @@ class _AutoParallelContext:
176
177
  if comm_type == _ParallelFusionConfig.REDUCESCATTER:
177
178
  self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
178
179
 
179
-
180
180
  def fusion_threshold_mb(self):
181
181
  """Get all reduce threshold."""
182
182
  self.check_context_handle()
@@ -229,6 +229,22 @@ class _AutoParallelContext:
229
229
  self.check_context_handle()
230
230
  return self._context_handle.get_pipeline_stage_split_num()
231
231
 
232
+ def set_pipeline_segments(self, segments):
233
+ """Set the segments of the pipeline"""
234
+ if isinstance(segments, bool) or not isinstance(segments, int):
235
+ raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
236
+ "must be int, but got the type : {}.".format(type(segments)))
237
+ if segments < 1:
238
+ raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
239
+ "should be greater or equal 1, but got the value of segments : {}.".format(segments))
240
+ self.check_context_handle()
241
+ self._context_handle.set_pipeline_segment_split_num(segments)
242
+
243
+ def get_pipeline_segments(self):
244
+ """Get the stages of the pipeline"""
245
+ self.check_context_handle()
246
+ return self._context_handle.get_pipeline_segment_split_num()
247
+
232
248
  def set_gradients_mean(self, gradients_mean):
233
249
  """
234
250
  Set gradients_mean flag.
@@ -491,6 +507,9 @@ class _AutoParallelContext:
491
507
  Args:
492
508
  grad_accumulation_step (int): The grad accumulation step.
493
509
  """
510
+ if grad_accumulation_step > 1:
511
+ raise ValueError("The interface is deprecated. To use gradient accumulation, "
512
+ "please use GradAccumulationCell in mindspore.nn.wrap.cell_wrapper.")
494
513
  self.check_context_handle()
495
514
  Validator.check_positive_int(grad_accumulation_step)
496
515
  self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
@@ -758,6 +777,11 @@ class _AutoParallelContext:
758
777
  .format(type(enable_parallel_optimizer)))
759
778
  self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
760
779
 
780
+ def get_enable_fold_pipeline(self):
781
+ """Get parallel optimizer flag."""
782
+ self.check_context_handle()
783
+ return self._context_handle.get_enable_fold_pipeline()
784
+
761
785
  def get_enable_parallel_optimizer(self):
762
786
  """Get parallel optimizer flag."""
763
787
  self.check_context_handle()
@@ -767,8 +791,6 @@ class _AutoParallelContext:
767
791
  r"""
768
792
  Set the configure for parallel optimizer. The configure provides more detailed behavior control about parallel
769
793
  training when parallel optimizer is enabled.
770
- Currently it supports the key `gradient_accumulation_shard`. The configure will be effective
771
- when we use context.set_auto_parallel_context(enable_parallel_optimizer=True).
772
794
 
773
795
  Args:
774
796
  parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer
@@ -786,14 +808,21 @@ class _AutoParallelContext:
786
808
  enabled, parameters with size smaller than this threshold will not be
787
809
  sharded across the devices. Parameter size = shape[0] \* ... \*
788
810
  shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
811
+ - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specific the
812
+ maximum group size across devices when the parallel optimizer is
813
+ enabled. The numerical range can be (0, device_num]. Default value
814
+ is -1, which means the optimizer weight shard group size will
815
+ the data parallel group of each parameter. Default -1.
816
+
789
817
  """
790
818
  self.check_context_handle()
791
819
  grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
792
820
  threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
821
+ optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE
793
822
 
794
823
  for config_name in parallel_optimizer_config:
795
824
  unknown_config = []
796
- if config_name not in [grad_shard_name, threshold_name]:
825
+ if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name]:
797
826
  unknown_config.append(config_name)
798
827
 
799
828
  if unknown_config:
@@ -811,6 +840,11 @@ class _AutoParallelContext:
811
840
  self._context_handle.set_parallel_optimizer_threshold(
812
841
  parallel_optimizer_config[threshold_name])
813
842
 
843
+ if optimizer_weight_shard_size_name in parallel_optimizer_config:
844
+ value = parallel_optimizer_config[optimizer_weight_shard_size_name]
845
+ Validator.check_positive_int(value)
846
+ self.set_optimizer_weight_shard_size(value)
847
+
814
848
  def get_grad_accumulation_shard(self):
815
849
  """Get grad accumulation shard."""
816
850
  self.check_context_handle()
@@ -890,6 +924,13 @@ class _AutoParallelContext:
890
924
  self.check_context_handle()
891
925
  return self._context_handle.get_optimizer_weight_shard_size()
892
926
 
927
+ def set_ops_strategy_json_config(self, type, path, mode):
928
+ """
929
+ Set configuration of saving ops strategy in file .json.
930
+ """
931
+ self.check_context_handle()
932
+ self._context_handle.set_ops_strategy_json_config(type, path, mode)
933
+
893
934
  def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
894
935
  """
895
936
  Set optimizer_weight_shard_aggregated_save.
@@ -1027,8 +1068,28 @@ class _AutoParallelContext:
1027
1068
  self.set_enable_all_gather_fusion(openstate)
1028
1069
  self.set_enable_reduce_scatter_fusion(openstate)
1029
1070
 
1071
+ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
1072
+ """
1073
+ Set strategy json configuration.
1030
1074
 
1075
+ Args:
1076
+ type (str): Whether to save or load the .json file; must be 'SAVE' or 'LOAD'.
1077
+ path (str): Path to save or load parallel strategy json.
1078
+ mode (str): Whether to save all operators or only the principal ones; must be 'all' or 'principal'.
1031
1079
 
1080
+ Raises:
1081
+ KeyError: When type is not 'SAVE' or 'LOAD'.
1082
+ KeyError: When mode is not 'all' or 'principal'.
1083
+ """
1084
+ dir_path = os.path.dirname(path)
1085
+ if dir_path and not os.path.exists(dir_path):
1086
+ os.makedirs(dir_path)
1087
+ check_type = ["SAVE", "LOAD"]
1088
+ check_mode = ["all", "principal"]
1089
+ if type in check_type and mode in check_mode:
1090
+ auto_parallel_context().set_ops_strategy_json_config(type, path, mode)
1091
+ else:
1092
+ raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")
1032
1093
 
1033
1094
  _AUTO_PARALLEL_CONTEXT = None
1034
1095
 
@@ -1053,6 +1114,7 @@ _set_auto_parallel_context_func_map = {
1053
1114
  "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
1054
1115
  "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
1055
1116
  "pipeline_stages": auto_parallel_context().set_pipeline_stages,
1117
+ "pipeline_segments": auto_parallel_context().set_pipeline_segments,
1056
1118
  "parallel_mode": auto_parallel_context().set_parallel_mode,
1057
1119
  "search_mode": auto_parallel_context().set_strategy_search_mode,
1058
1120
  "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
@@ -1074,7 +1136,6 @@ _set_auto_parallel_context_func_map = {
1074
1136
  "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
1075
1137
  "comm_fusion": auto_parallel_context().set_comm_fusion}
1076
1138
 
1077
-
1078
1139
  _get_auto_parallel_context_func_map = {
1079
1140
  "device_num": auto_parallel_context().get_device_num,
1080
1141
  "global_rank": auto_parallel_context().get_global_rank,
@@ -1111,7 +1172,6 @@ _get_auto_parallel_context_func_map = {
1111
1172
  communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
1112
1173
  optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
1113
1174
  strategy_ckpt_config=dict)
1114
-
1115
1175
  def _set_auto_parallel_context(**kwargs):
1116
1176
  """
1117
1177
  Set auto parallel context.
@@ -1247,8 +1307,8 @@ def _reset_auto_parallel_context():
1247
1307
  - strategy_ckpt_load_file: ""
1248
1308
  - strategy_ckpt_save_file: ""
1249
1309
  - enable_parallel_optimizer: False
1250
- - search_mode: dynamic_programming
1251
- - auto_parallel_search_mode: dynamic_programming
1310
+ - search_mode: recursive_programming
1311
+ - auto_parallel_search_mode: recursive_programming
1252
1312
  - sharding_propagation: False
1253
1313
  - pipeline_stages: 0
1254
1314
  - gradient_accumulation_shard: True
@@ -475,7 +475,7 @@ class _CostModelContext:
475
475
  """
476
476
  if self._context_handle is None:
477
477
  raise ValueError("Context handle is none in context!!!")
478
- return self._context_handle.rp_matmul_mem_coef()
478
+ return self._context_handle.get_rp_matmul_mem_coef()
479
479
 
480
480
  def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
481
481
  """
@@ -693,7 +693,7 @@ def _set_rp_matmul_mem_coef(coef):
693
693
  cost_model_context().set_rp_matmul_mem_coef(coef)
694
694
 
695
695
 
696
- def _get_rp_matmul_mem_coef(self):
696
+ def _get_rp_matmul_mem_coef():
697
697
  """
698
698
  Get the matmul memory coef which is used in the RP algorithm.
699
699
  """