mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -62,6 +62,7 @@ class _ParallelOptimizerConfig:
62
62
  """
63
63
  GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
64
64
  PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"
65
+ OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"
65
66
 
66
67
 
67
68
  class _AutoParallelContext:
@@ -176,7 +177,6 @@ class _AutoParallelContext:
176
177
  if comm_type == _ParallelFusionConfig.REDUCESCATTER:
177
178
  self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
178
179
 
179
-
180
180
  def fusion_threshold_mb(self):
181
181
  """Get all reduce threshold."""
182
182
  self.check_context_handle()
@@ -229,6 +229,22 @@ class _AutoParallelContext:
229
229
  self.check_context_handle()
230
230
  return self._context_handle.get_pipeline_stage_split_num()
231
231
 
232
+ def set_pipeline_segments(self, segments):
233
+ """Set the segments of the pipeline"""
234
+ if isinstance(segments, bool) or not isinstance(segments, int):
235
+ raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
236
+ "must be int, but got the type : {}.".format(type(segments)))
237
+ if segments < 1:
238
+ raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
239
+ "should be greater or equal 1, but got the value of segments : {}.".format(segments))
240
+ self.check_context_handle()
241
+ self._context_handle.set_pipeline_segment_split_num(segments)
242
+
243
+ def get_pipeline_segments(self):
244
+ """Get the stages of the pipeline"""
245
+ self.check_context_handle()
246
+ return self._context_handle.get_pipeline_segment_split_num()
247
+
232
248
  def set_gradients_mean(self, gradients_mean):
233
249
  """
234
250
  Set gradients_mean flag.
@@ -370,7 +386,7 @@ class _AutoParallelContext:
370
386
  """
371
387
  Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
372
388
  will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
373
- will search the desired strategies. Default: False.
389
+ will search the desired strategies. Default: ``False``.
374
390
  This attribute is replaced by context.set_auto_parallel_context(search_mode="sharding_propagation").
375
391
 
376
392
  Args:
@@ -491,6 +507,9 @@ class _AutoParallelContext:
491
507
  Args:
492
508
  grad_accumulation_step (int): The grad accumulation step.
493
509
  """
510
+ if grad_accumulation_step > 1:
511
+ raise ValueError("The interface is deprecated. To use gradient accumulation, "
512
+ "please use GradAccumulationCell in mindspore.nn.wrap.cell_wrapper.")
494
513
  self.check_context_handle()
495
514
  Validator.check_positive_int(grad_accumulation_step)
496
515
  self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
@@ -758,6 +777,11 @@ class _AutoParallelContext:
758
777
  .format(type(enable_parallel_optimizer)))
759
778
  self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
760
779
 
780
+ def get_enable_fold_pipeline(self):
781
+ """Get parallel optimizer flag."""
782
+ self.check_context_handle()
783
+ return self._context_handle.get_enable_fold_pipeline()
784
+
761
785
  def get_enable_parallel_optimizer(self):
762
786
  """Get parallel optimizer flag."""
763
787
  self.check_context_handle()
@@ -767,8 +791,6 @@ class _AutoParallelContext:
767
791
  r"""
768
792
  Set the configure for parallel optimizer. The configure provides more detailed behavior control about parallel
769
793
  training when parallel optimizer is enabled.
770
- Currently it supports the key `gradient_accumulation_shard`. The configure will be effective
771
- when we use context.set_auto_parallel_context(enable_parallel_optimizer=True).
772
794
 
773
795
  Args:
774
796
  parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer
@@ -786,14 +808,21 @@ class _AutoParallelContext:
786
808
  enabled, parameters with size smaller than this threshold will not be
787
809
  sharded across the devices. Parameter size = shape[0] \* ... \*
788
810
  shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
811
+ - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specific the
812
+ maximum group size across devices when the parallel optimizer is
813
+ enabled. The numerical range can be (0, device_num]. Default value
814
+ is -1, which means the optimizer weight shard group size will
815
+ the data parallel group of each parameter. Default -1.
816
+
789
817
  """
790
818
  self.check_context_handle()
791
819
  grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
792
820
  threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
821
+ optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE
793
822
 
794
823
  for config_name in parallel_optimizer_config:
795
824
  unknown_config = []
796
- if config_name not in [grad_shard_name, threshold_name]:
825
+ if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name]:
797
826
  unknown_config.append(config_name)
798
827
 
799
828
  if unknown_config:
@@ -811,6 +840,11 @@ class _AutoParallelContext:
811
840
  self._context_handle.set_parallel_optimizer_threshold(
812
841
  parallel_optimizer_config[threshold_name])
813
842
 
843
+ if optimizer_weight_shard_size_name in parallel_optimizer_config:
844
+ value = parallel_optimizer_config[optimizer_weight_shard_size_name]
845
+ Validator.check_positive_int(value)
846
+ self.set_optimizer_weight_shard_size(value)
847
+
814
848
  def get_grad_accumulation_shard(self):
815
849
  """Get grad accumulation shard."""
816
850
  self.check_context_handle()
@@ -824,7 +858,7 @@ class _AutoParallelContext:
824
858
  def set_enable_alltoall(self, enable_a2a):
825
859
  """
826
860
  Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
827
- Default: False.
861
+ Default: ``False``.
828
862
 
829
863
  Args:
830
864
  enable_a2a (bool): Enable/disable AllToAll.
@@ -890,6 +924,13 @@ class _AutoParallelContext:
890
924
  self.check_context_handle()
891
925
  return self._context_handle.get_optimizer_weight_shard_size()
892
926
 
927
+ def set_ops_strategy_json_config(self, type, path, mode):
928
+ """
929
+ Set configuration of saving ops strategy in file .json.
930
+ """
931
+ self.check_context_handle()
932
+ self._context_handle.set_ops_strategy_json_config(type, path, mode)
933
+
893
934
  def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
894
935
  """
895
936
  Set optimizer_weight_shard_aggregated_save.
@@ -1027,7 +1068,28 @@ class _AutoParallelContext:
1027
1068
  self.set_enable_all_gather_fusion(openstate)
1028
1069
  self.set_enable_reduce_scatter_fusion(openstate)
1029
1070
 
1071
+ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
1072
+ """
1073
+ Set strategy json configuration.
1030
1074
 
1075
+ Args:
1076
+ type (str): The parameter for choosing save or load .json file.
1077
+ path (str): Path to save or load parallel strategy json.
1078
+ mode (str): The parameter for choosing save all or important operators.
1079
+
1080
+ Raises:
1081
+ KeyError: When type is not 'SAVE' or 'LOAD'.
1082
+ KeyError: When mode is not 'all' or 'principal'.
1083
+ """
1084
+ dir_path = os.path.dirname(path)
1085
+ if dir_path and not os.path.exists(dir_path):
1086
+ os.makedirs(dir_path)
1087
+ check_type = ["SAVE", "LOAD"]
1088
+ check_mode = ["all", "principal"]
1089
+ if type in check_type and mode in check_mode:
1090
+ auto_parallel_context().set_ops_strategy_json_config(type, path, mode)
1091
+ else:
1092
+ raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")
1031
1093
 
1032
1094
  _AUTO_PARALLEL_CONTEXT = None
1033
1095
 
@@ -1052,6 +1114,7 @@ _set_auto_parallel_context_func_map = {
1052
1114
  "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
1053
1115
  "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
1054
1116
  "pipeline_stages": auto_parallel_context().set_pipeline_stages,
1117
+ "pipeline_segments": auto_parallel_context().set_pipeline_segments,
1055
1118
  "parallel_mode": auto_parallel_context().set_parallel_mode,
1056
1119
  "search_mode": auto_parallel_context().set_strategy_search_mode,
1057
1120
  "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
@@ -1073,7 +1136,6 @@ _set_auto_parallel_context_func_map = {
1073
1136
  "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
1074
1137
  "comm_fusion": auto_parallel_context().set_comm_fusion}
1075
1138
 
1076
-
1077
1139
  _get_auto_parallel_context_func_map = {
1078
1140
  "device_num": auto_parallel_context().get_device_num,
1079
1141
  "global_rank": auto_parallel_context().get_global_rank,
@@ -1110,7 +1172,6 @@ _get_auto_parallel_context_func_map = {
1110
1172
  communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
1111
1173
  optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
1112
1174
  strategy_ckpt_config=dict)
1113
-
1114
1175
  def _set_auto_parallel_context(**kwargs):
1115
1176
  """
1116
1177
  Set auto parallel context.
@@ -1121,11 +1182,11 @@ def _set_auto_parallel_context(**kwargs):
1121
1182
  Args:
1122
1183
  device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
1123
1184
  global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
1124
- gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
1185
+ gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: ``False``.
1125
1186
  loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
1126
- calculations. Default: True.
1187
+ calculations. Default: ``True``.
1127
1188
  gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
1128
- Default: True.
1189
+ Default: ``True``.
1129
1190
  parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
1130
1191
  "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
1131
1192
 
@@ -1151,13 +1212,13 @@ def _set_auto_parallel_context(**kwargs):
1151
1212
  for forward compatibility, and this attribute will be deleted in a future MindSpore version.
1152
1213
  parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
1153
1214
  "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
1154
- broadcast. Default: False.
1215
+ broadcast. Default: ``False``.
1155
1216
  strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
1156
1217
  strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
1157
1218
  group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
1158
- full_batch (bool): Whether to load the whole batch on each device. Default: False.
1219
+ full_batch (bool): Whether to load the whole batch on each device. Default: ``False``.
1159
1220
  dataset_strategy Union[str, tuple]: Dataset sharding strategy. Default: "data_parallel".
1160
- enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
1221
+ enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: ``False``.
1161
1222
  all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
1162
1223
  pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
1163
1224
  the devices are distributed alone the pipeline. The total devices will be divided into
@@ -1175,13 +1236,13 @@ def _set_auto_parallel_context(**kwargs):
1175
1236
  It should be larger than one and less than or equal with the data parallel size.
1176
1237
  Default: -1, which means fully use parallel optimizer in data parallel dimension.
1177
1238
  optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel
1178
- optimizer. Default: False.
1239
+ optimizer. Default: ``False``.
1179
1240
  sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
1180
1241
  the strategy-configured operators will propagate the strategies to other
1181
1242
  operators with minimum redistribution cost; otherwise, the algorithm will
1182
- search the desired strategies. Default: False.
1243
+ search the desired strategies. Default: ``False``.
1183
1244
  enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
1184
- circumvent AllToAll. Default: False.
1245
+ circumvent AllToAll. Default: ``False``.
1185
1246
  comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. each
1186
1247
  communication fusion config has two keys: "mode" and "config".
1187
1248
  It supports following communication fusion types and configurations:
@@ -1246,8 +1307,8 @@ def _reset_auto_parallel_context():
1246
1307
  - strategy_ckpt_load_file: ""
1247
1308
  - strategy_ckpt_save_file: ""
1248
1309
  - enable_parallel_optimizer: False
1249
- - search_mode: dynamic_programming
1250
- - auto_parallel_search_mode: dynamic_programming
1310
+ - search_mode: 'recursive_programming
1311
+ - auto_parallel_search_mode: 'recursive_programming
1251
1312
  - sharding_propagation: False
1252
1313
  - pipeline_stages: 0
1253
1314
  - gradient_accumulation_shard: True
@@ -452,6 +452,31 @@ class _CostModelContext:
452
452
  raise ValueError("Context handle is none in context!!!")
453
453
  return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()
454
454
 
455
+ def set_rp_matmul_mem_coef(self, coef):
456
+ """
457
+ Set the matmul memory coef which is used in the RP algorithm.
458
+
459
+ Args:
460
+ coef (int): The coefficient of memory cost in RP.
461
+
462
+ Raises:
463
+ ValueError: If context handle is none.
464
+ """
465
+ if self._context_handle is None:
466
+ raise ValueError("Context handle is none in context!!!")
467
+ self._context_handle.set_rp_matmul_mem_coef(coef)
468
+
469
+ def get_rp_matmul_mem_coef(self):
470
+ """
471
+ Get the matmul memory coef which is used in the RP algorithm.
472
+
473
+ Raises:
474
+ ValueError: If context handle is none.
475
+ """
476
+ if self._context_handle is None:
477
+ raise ValueError("Context handle is none in context!!!")
478
+ return self._context_handle.get_rp_matmul_mem_coef()
479
+
455
480
  def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
456
481
  """
457
482
  Set costmodel allreduce fusion computation time parameter.
@@ -656,3 +681,20 @@ def _get_algo_single_loop():
656
681
  Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
657
682
  """
658
683
  return cost_model_context().get_dp_algo_single_loop()
684
+
685
+
686
+ def _set_rp_matmul_mem_coef(coef):
687
+ """
688
+ Set the matmul memory coef which is used in the RP algorithm.
689
+
690
+ Args:
691
+ coef (int): The coefficient of memory cost in RP.
692
+ """
693
+ cost_model_context().set_rp_matmul_mem_coef(coef)
694
+
695
+
696
+ def _get_rp_matmul_mem_coef():
697
+ """
698
+ Get the matmul memory coef which is used in the RP algorithm.
699
+ """
700
+ return cost_model_context().get_rp_matmul_mem_coef()
@@ -22,21 +22,27 @@ from mindspore._c_expression import OffloadContext
22
22
  from mindspore._checkparam import args_type_check
23
23
  from mindspore import _checkparam as Validator
24
24
 
25
+ K_RE_PATTERN = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
26
+ K_GBTOBYTE = 1 << 30
27
+
25
28
 
26
29
  class _OffloadConfig:
27
30
  """
28
31
  The key of the Offload Config.
29
32
  """
30
- ENABLE_OFFLOAD = "enable_offload"
31
33
  OFFLOAD_PARAM = "offload_param"
32
34
  OFFLOAD_PATH = "offload_path"
35
+ OFFLOAD_CPU_SIZE = "offload_cpu_size"
33
36
  OFFLOAD_CHECKPOINT = "offload_checkpoint"
34
- OFFLOAD_DDR_SIZE = "offload_ddr_size"
35
37
  OFFLOAD_DISK_SIZE = "offload_disk_size"
36
38
  ENABLE_AIO = "enable_aio"
37
39
  AIO_BLOCK_SIZE = "aio_block_size"
38
40
  AIO_QUEUE_DEPTH = "aio_queue_depth"
39
41
  ENABLE_PINNED_MEM = "enable_pinned_mem"
42
+ AUTO_OFFLOAD = "auto_offload"
43
+ CPU_RATIO = "cpu_ratio"
44
+ HBM_RATIO = "hbm_ratio"
45
+ HOST_MEM_BLOCk_SIZE = "host_mem_block_size"
40
46
 
41
47
 
42
48
  class _OffloadContext:
@@ -70,97 +76,164 @@ class _OffloadContext:
70
76
  if self._context_handle is None:
71
77
  raise ValueError("Context handle is none in context!!!")
72
78
 
79
+ def set_offload_param(self, offload_param):
80
+ """Set offload_param"""
81
+ if not isinstance(offload_param, str):
82
+ raise TypeError("For 'set_offload_param', "
83
+ "the argument 'offload_param' must be str, but got the type : {}."
84
+ .format(type(offload_param)))
85
+ Validator.check_string(offload_param.lower(), ["cpu", "disk"], "offload_param", "set_offload_param")
86
+ self._context_handle.set_offload_param(offload_param.lower())
87
+
88
+ def set_offload_checkpoint(self, offload_checkpoint):
89
+ """Set offload_checkpoint"""
90
+ if not isinstance(offload_checkpoint, str):
91
+ raise TypeError("For 'set_offload_checkpoint', "
92
+ "the argument 'offload_checkpoint' must be str, but got the type : {}."
93
+ .format(type(offload_checkpoint)))
94
+ Validator.check_string(offload_checkpoint.lower(), ["cpu", "disk"], "offload_checkpoint",
95
+ "set_offload_checkpoint")
96
+ self._context_handle.set_offload_checkpoint(offload_checkpoint.lower())
97
+
98
+ def set_offload_path(self, offload_path):
99
+ """Set offload_path"""
100
+ if not isinstance(offload_path, str):
101
+ raise TypeError("For 'set_offload_path', "
102
+ "the argument 'offload_path' must be str, but got the type : {}."
103
+ .format(type(offload_path)))
104
+ self._context_handle.set_offload_path(offload_path)
105
+
106
+ def set_offload_cpu_size(self, offload_cpu_size):
107
+ """Set offload_cpu_size"""
108
+ if not isinstance(offload_cpu_size, str):
109
+ raise TypeError("For 'set_offload_cpu_size', "
110
+ "the argument 'offload_cpu_size' must be str, but got the type : {}."
111
+ .format(type(offload_cpu_size)))
112
+ if not Validator.check_str_by_regular(offload_cpu_size, K_RE_PATTERN):
113
+ raise ValueError("The argument 'offload_cpu_size' should be in correct "
114
+ " format! It must be a string ending with 'GB', in addition to that, it must contain "
115
+ "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
116
+ .format(offload_cpu_size))
117
+ ddr_size = float(offload_cpu_size[:-2])
118
+ self._context_handle.set_offload_cpu_size(int(ddr_size * K_GBTOBYTE))
119
+
120
+ def set_offload_disk_size(self, offload_disk_size):
121
+ """Set offload_disk_size"""
122
+ if not isinstance(offload_disk_size, str):
123
+ raise TypeError("For 'set_offload_disk_size', "
124
+ "the argument 'offload_disk_size' must be str, but got the type : {}."
125
+ .format(type(offload_disk_size)))
126
+ if not Validator.check_str_by_regular(offload_disk_size, K_RE_PATTERN):
127
+ raise ValueError("The argument 'offload_disk_size' should be in correct "
128
+ " format! It must be a string ending with 'GB', in addition to that, it must contain "
129
+ "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
130
+ .format(offload_disk_size))
131
+ disk_size = float(offload_disk_size[:-2])
132
+ self._context_handle.set_offload_disk_size(int(disk_size * K_GBTOBYTE))
133
+
134
+ def set_enable_aio(self, enable_aio):
135
+ """Set enable_aio"""
136
+ Validator.check_bool(enable_aio, "enable_aio", "set_enable_aio")
137
+ self._context_handle.set_enable_aio(enable_aio)
138
+
139
+ def set_aio_block_size(self, aio_block_size):
140
+ """Set aio_block_size"""
141
+ if not isinstance(aio_block_size, str):
142
+ raise TypeError("For 'set_aio_block_size', "
143
+ "the argument 'aio_block_size' must be str, but got the type : {}."
144
+ .format(type(aio_block_size)))
145
+ if not Validator.check_str_by_regular(aio_block_size, K_RE_PATTERN):
146
+ raise ValueError("The argument 'aio_block_size' should be in correct "
147
+ " format! It must be a string ending with 'GB', in addition to that, it must contain "
148
+ "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
149
+ .format(aio_block_size))
150
+ aio_size = float(aio_block_size[:-2])
151
+ self._context_handle.set_aio_block_size(int(aio_size * K_GBTOBYTE))
152
+
153
+ def set_aio_queue_depth(self, aio_queue_depth):
154
+ """Set aio_queue_depth"""
155
+ Validator.check_positive_int(
156
+ aio_queue_depth, "aio_queue_depth", "set_aio_queue_depth")
157
+ self._context_handle.set_aio_queue_depth(aio_queue_depth)
158
+
159
+ def set_enable_pinned_mem(self, enable_pinned_mem):
160
+ """Set enable_pinned_mem"""
161
+ Validator.check_bool(
162
+ enable_pinned_mem, "enable_pinned_mem", "set_enable_pinned_mem")
163
+ self._context_handle.set_enable_pinned_mem(enable_pinned_mem)
164
+
165
+ def set_auto_offload(self, auto_offload):
166
+ """Set auto_offload"""
167
+ Validator.check_bool(auto_offload, "auto_offload", "set_auto_offload")
168
+ self._context_handle.set_auto_offload(auto_offload)
169
+
170
+ def set_host_mem_block_size(self, host_mem_block_size):
171
+ """Set host_mem_block_size"""
172
+ if not isinstance(host_mem_block_size, str):
173
+ raise TypeError("For 'set_host_mem_block_size', "
174
+ "the argument 'host_mem_block_size' must be str, but got the type : {}."
175
+ .format(type(host_mem_block_size)))
176
+ if not Validator.check_str_by_regular(host_mem_block_size, K_RE_PATTERN):
177
+ raise ValueError("The argument 'host_mem_block_size' should be in correct "
178
+ " format! It must be a string ending with 'GB', in addition to that, it must contain "
179
+ "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
180
+ .format(host_mem_block_size))
181
+ block_size = float(host_mem_block_size[:-2])
182
+ self._context_handle.set_host_mem_block_size(
183
+ int(block_size * K_GBTOBYTE))
184
+
185
+ def set_cpu_ratio(self, cpu_ratio):
186
+ """Set cpu_ratio"""
187
+ Validator.check_float_range(
188
+ cpu_ratio, 0, 1, Validator.INC_RIGHT, 'cpu_ratio')
189
+ self._context_handle.set_cpu_ratio(cpu_ratio)
190
+
191
+ def set_hbm_ratio(self, hbm_ratio):
192
+ """Set hbm_ratio"""
193
+ Validator.check_float_range(
194
+ hbm_ratio, 0, 1, Validator.INC_RIGHT, 'hbm_ratio')
195
+ self._context_handle.set_hbm_ratio(hbm_ratio)
196
+
73
197
  def set_offload_config(self, offload_config):
74
198
  """Set offfload context"""
75
199
  self.check_context_handle()
76
- enable_offload = _OffloadConfig.ENABLE_OFFLOAD
77
- offload_param = _OffloadConfig.OFFLOAD_PARAM
78
- offload_path = _OffloadConfig.OFFLOAD_PATH
79
- offload_checkpoint = _OffloadConfig.OFFLOAD_CHECKPOINT
80
- offload_ddr_size = _OffloadConfig.OFFLOAD_DDR_SIZE
81
- offload_disk_size = _OffloadConfig.OFFLOAD_DISK_SIZE
82
- enable_aio = _OffloadConfig.ENABLE_AIO
83
- aio_block_size = _OffloadConfig.AIO_BLOCK_SIZE
84
- aio_queue_depth = _OffloadConfig.AIO_QUEUE_DEPTH
85
- enable_pinned_mem = _OffloadConfig.ENABLE_PINNED_MEM
86
200
 
87
201
  for config_name in offload_config:
88
202
  unknown_config = []
89
- if config_name not in [enable_offload, offload_param, offload_path, offload_checkpoint,
90
- offload_ddr_size, offload_disk_size, enable_aio, aio_block_size,
91
- aio_queue_depth, enable_pinned_mem]:
203
+ if config_name not in [_OffloadConfig.OFFLOAD_PARAM, _OffloadConfig.OFFLOAD_PATH,
204
+ _OffloadConfig.CPU_RATIO, _OffloadConfig.HOST_MEM_BLOCk_SIZE,
205
+ _OffloadConfig.HBM_RATIO, _OffloadConfig.OFFLOAD_CPU_SIZE,
206
+ _OffloadConfig.OFFLOAD_DISK_SIZE, _OffloadConfig.ENABLE_AIO,
207
+ _OffloadConfig.AIO_BLOCK_SIZE, _OffloadConfig.AIO_QUEUE_DEPTH,
208
+ _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD,
209
+ _OffloadConfig.OFFLOAD_CHECKPOINT]:
92
210
  unknown_config.append(config_name)
93
211
 
94
212
  if unknown_config:
95
213
  raise ValueError("Unknown config: {}".format(unknown_config))
96
-
97
- if enable_offload in offload_config:
98
- Validator.check_bool(
99
- offload_config[enable_offload], enable_offload, enable_offload)
100
- self._context_handle.set_enable_offload(
101
- offload_config[enable_offload])
102
-
103
- if offload_param in offload_config:
104
- Validator.check_string(
105
- offload_config[offload_param].lower(), ["cpu", "disk"])
106
- self._context_handle.set_offload_param(
107
- offload_config[offload_param].lower())
108
-
109
- if offload_path in offload_config:
110
- if not isinstance(offload_config[offload_path], str):
111
- raise TypeError("For 'set_offload_path', "
112
- "the argument 'offload_path' must be str, but got the type : {}."
113
- .format(type(offload_config[offload_path])))
114
- self._context_handle.set_offload_path(
115
- offload_config[offload_path])
116
- if offload_checkpoint in offload_config:
117
- Validator.check_string(
118
- offload_config[offload_checkpoint].lower(), ["cpu", "disk"])
119
- self._context_handle.set_offload_checkpoint(
120
- offload_config[offload_checkpoint].lower())
121
-
122
- if offload_ddr_size in offload_config:
123
- Validator.check_positive_int(offload_config[offload_ddr_size])
124
- self._context_handle.set_offload_ddr_size(
125
- offload_config[offload_ddr_size])
126
-
127
- if offload_disk_size in offload_config:
128
- Validator.check_positive_int(offload_config[offload_disk_size])
129
- self._context_handle.set_offload_disk_size(
130
- offload_config[offload_disk_size])
131
- if enable_aio in offload_config:
132
- Validator.check_bool(
133
- offload_config[enable_aio], enable_aio, enable_aio)
134
- self._context_handle.set_enable_aio(
135
- offload_config[enable_aio])
136
- if aio_block_size in offload_config:
137
- Validator.check_positive_int(offload_config[aio_block_size])
138
- self._context_handle.set_aio_block_size(
139
- offload_config[aio_block_size])
140
- if aio_queue_depth in offload_config:
141
- Validator.check_positive_int(offload_config[aio_queue_depth])
142
- self._context_handle.set_aio_queue_depth(
143
- offload_config[aio_queue_depth])
144
- if enable_pinned_mem in offload_config:
145
- Validator.check_bool(
146
- offload_config[enable_pinned_mem], enable_pinned_mem, enable_pinned_mem)
147
- self._context_handle.set_enable_pinned_mem(
148
- offload_config[enable_pinned_mem])
214
+ func = _set_offload_context_func_map.get(config_name, None)
215
+ if not func:
216
+ raise ValueError(
217
+ "Can not find set function: {}".format(config_name))
218
+ func(offload_config[config_name])
149
219
 
150
220
  def offload_config(self):
151
221
  """Get config of offload"""
152
222
  self.check_context_handle()
153
223
  offload_config = {
154
- _OffloadConfig.ENABLE_OFFLOAD: self._context_handle.enable_offload(),
155
224
  _OffloadConfig.OFFLOAD_PARAM: self._context_handle.offload_param(),
156
225
  _OffloadConfig.OFFLOAD_PATH: self._context_handle.offload_path(),
157
- _OffloadConfig.OFFLOAD_CHECKPOINT: self._context_handle.offload_checkpoint(),
158
- _OffloadConfig.OFFLOAD_DDR_SIZE: self._context_handle.offload_ddr_size(),
226
+ _OffloadConfig.OFFLOAD_CPU_SIZE: self._context_handle.offload_cpu_size(),
159
227
  _OffloadConfig.OFFLOAD_DISK_SIZE: self._context_handle.offload_disk_size(),
160
228
  _OffloadConfig.ENABLE_AIO: self._context_handle.enable_aio(),
161
229
  _OffloadConfig.AIO_BLOCK_SIZE: self._context_handle.aio_block_size(),
162
230
  _OffloadConfig.AIO_QUEUE_DEPTH: self._context_handle.aio_queue_depth(),
163
- _OffloadConfig.ENABLE_PINNED_MEM: self._context_handle.enable_pinned_mem()
231
+ _OffloadConfig.ENABLE_PINNED_MEM: self._context_handle.enable_pinned_mem(),
232
+ _OffloadConfig.AUTO_OFFLOAD: self._context_handle.auto_offload(),
233
+ _OffloadConfig.HOST_MEM_BLOCk_SIZE: self._context_handle.host_mem_block_size(),
234
+ _OffloadConfig.CPU_RATIO: self._context_handle.cpu_ratio(),
235
+ _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio(),
236
+ _OffloadConfig.OFFLOAD_CHECKPOINT: self._context_handle.offload_checkpoint()
164
237
  }
165
238
  return offload_config
166
239
 
@@ -183,3 +256,20 @@ def _set_offload_context(offload_config):
183
256
 
184
257
  def _get_offload_context():
185
258
  return offload_context().offload_config()
259
+
260
+
261
+ _set_offload_context_func_map = {
262
+ _OffloadConfig.OFFLOAD_PARAM: offload_context().set_offload_param,
263
+ _OffloadConfig.OFFLOAD_PATH: offload_context().set_offload_path,
264
+ _OffloadConfig.OFFLOAD_CPU_SIZE: offload_context().set_offload_cpu_size,
265
+ _OffloadConfig.OFFLOAD_DISK_SIZE: offload_context().set_offload_disk_size,
266
+ _OffloadConfig.ENABLE_AIO: offload_context().set_enable_aio,
267
+ _OffloadConfig.AIO_BLOCK_SIZE: offload_context().set_aio_block_size,
268
+ _OffloadConfig.AIO_QUEUE_DEPTH: offload_context().set_aio_queue_depth,
269
+ _OffloadConfig.ENABLE_PINNED_MEM: offload_context().set_enable_pinned_mem,
270
+ _OffloadConfig.AUTO_OFFLOAD: offload_context().set_auto_offload,
271
+ _OffloadConfig.HOST_MEM_BLOCk_SIZE: offload_context().set_host_mem_block_size,
272
+ _OffloadConfig.CPU_RATIO: offload_context().set_cpu_ratio,
273
+ _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio,
274
+ _OffloadConfig.OFFLOAD_CHECKPOINT: offload_context().set_offload_checkpoint
275
+ }
@@ -330,8 +330,8 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
330
330
  device_list = list(range(0, np.prod(from_tensor_layout[0])))
331
331
  param_rank_list = _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, rank_id)
332
332
  param_rank_list_new = [rank % from_device_num for rank in param_rank_list]
333
- param_rank_list_new = set(param_rank_list_new)
334
- result_list.update(param_rank_list_new)
333
+ param_rank_set_new = set(param_rank_list_new)
334
+ result_list.update(param_rank_set_new)
335
335
  return list(result_list)
336
336
 
337
337
 
@@ -114,10 +114,10 @@ def _set_ps_context(**kwargs):
114
114
  Args:
115
115
  enable_ps (bool): Whether to enable parameter server training mode.
116
116
  Only after enable_ps is set True, the environment variables will be effective.
117
- Default: False.
117
+ Default: ``False``.
118
118
  config_file_path (string): Configuration file path used by recovery. Default: ''.
119
119
  scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
120
- enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: False.
120
+ enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False``.
121
121
  client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ''.
122
122
  server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ''.
123
123
 
@@ -180,8 +180,8 @@ def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size,
180
180
  ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size, param_key)
181
181
 
182
182
 
183
- def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
184
- ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
183
+ def _reinsert_hash_table_size(new_name, cur_name):
184
+ ps_context().reinsert_hash_table_size(new_name, cur_name)
185
185
 
186
186
 
187
187
  def _insert_accumu_init_info(name, init_val):
@@ -228,3 +228,15 @@ def _enable_distributed_mindrt():
228
228
  This method is used to distinguish from old distributed training mode.
229
229
  '''
230
230
  return ps_context().enable_distributed_mindrt()
231
+
232
+
233
+ def _set_checkpoint_load_status(status):
234
+ return ps_context().set_checkpoint_load_status(status)
235
+
236
+
237
+ def _store_warm_up_ptr_by_tensor(param_key, tensor):
238
+ return ps_context().store_warm_up_ptr_by_tensor(param_key, tensor)
239
+
240
+
241
+ def _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor):
242
+ return ps_context().store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
@@ -65,7 +65,8 @@ def _set_recovery_context(**kwargs):
65
65
 
66
66
  Args:
67
67
  ckpt_path (string): Set the recovery path used to save checkpoint. Default: ''.
68
- need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery. Default: False.
68
+ need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery.
69
+ Default: ``False``.
69
70
 
70
71
  Raises:
71
72
  ValueError: If input key is not the attribute in recovery context.