mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; refer to the original registry page for more details.

Files changed (870):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,8 @@ from PIL import Image
23
23
  import mindspore
24
24
  import mindspore._c_dataengine as cde
25
25
 
26
- if Image.__version__ >= "9.1.0":
26
+ # The following constants have been deprecated by Pillow since version 9.1.0
27
+ if int(Image.__version__.split(".")[0]) > 9 or Image.__version__ >= "9.1.0":
27
28
  FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT
28
29
  FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM
29
30
  PERSPECTIVE = Image.Transform.PERSPECTIVE
@@ -47,14 +48,14 @@ class AutoAugmentPolicy(str, Enum):
47
48
  """
48
49
  AutoAugment policy for different datasets.
49
50
 
50
- Possible enumeration values are: AutoAugmentPolicy.IMAGENET, AutoAugmentPolicy.CIFAR10,
51
+ Possible enumeration values are: ``AutoAugmentPolicy.IMAGENET``, ``AutoAugmentPolicy.CIFAR10``,
51
52
  AutoAugmentPolicy.SVHN.
52
53
 
53
54
  Each policy contains 25 pairs of augmentation operations. When using AutoAugment, each image is randomly
54
55
  transformed with one of these operation pairs. Each pair has 2 different operations. The following shows
55
56
  all of these augmentation operations, including operation names with their probabilities and random params.
56
57
 
57
- - AutoAugmentPolicy.IMAGENET: dataset auto augment policy for ImageNet.
58
+ - ``AutoAugmentPolicy.IMAGENET``: dataset auto augment policy for ImageNet.
58
59
 
59
60
  .. code-block::
60
61
 
@@ -73,7 +74,7 @@ class AutoAugmentPolicy(str, Enum):
73
74
  (("Invert", 0.6, None), ("Equalize", 1.0, None)), (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
74
75
  (("Equalize", 0.8, None), ("Equalize", 0.6, None))]
75
76
 
76
- - AutoAugmentPolicy.CIFAR10: dataset auto augment policy for Cifar10.
77
+ - ``AutoAugmentPolicy.CIFAR10``: dataset auto augment policy for Cifar10.
77
78
 
78
79
  .. code-block::
79
80
 
@@ -94,7 +95,7 @@ class AutoAugmentPolicy(str, Enum):
94
95
  (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
95
96
  (("Equalize", 0.2, None), ("AutoContrast", 0.6, None))]
96
97
 
97
- - AutoAugmentPolicy.SVHN: dataset auto augment policy for SVHN.
98
+ - ``AutoAugmentPolicy.SVHN``: dataset auto augment policy for SVHN.
98
99
 
99
100
  .. code-block::
100
101
 
@@ -133,13 +134,13 @@ class Border(str, Enum):
133
134
  """
134
135
  Padding Mode, Border Type.
135
136
 
136
- Possible enumeration values are: Border.CONSTANT, Border.EDGE, Border.REFLECT, Border.SYMMETRIC.
137
+ Possible enumeration values are: ``Border.CONSTANT``, ``Border.EDGE``, ``Border.REFLECT``, ``Border.SYMMETRIC``.
137
138
 
138
- - Border.CONSTANT: means it fills the border with constant values.
139
- - Border.EDGE: means it pads with the last value on the edge.
140
- - Border.REFLECT: means it reflects the values on the edge omitting the last value of edge.
139
+ - ``Border.CONSTANT`` : means it fills the border with constant values.
140
+ - ``Border.EDGE`` : means it pads with the last value on the edge.
141
+ - ``Border.REFLECT`` : means it reflects the values on the edge omitting the last value of edge.
141
142
  For example, padding [1,2,3,4] with 2 elements on both sides will result in [3,2,1,2,3,4,3,2].
142
- - Border.SYMMETRIC: means it reflects the values on the edge repeating the last value of edge.
143
+ - ``Border.SYMMETRIC`` : means it reflects the values on the edge repeating the last value of edge.
143
144
  For example, padding [1,2,3,4] with 2 elements on both sides will result in [2,1,1,2,3,4,4,3].
144
145
 
145
146
  Note:
@@ -256,10 +257,10 @@ class ImageBatchFormat(IntEnum):
256
257
  """
257
258
  Data Format of images after batch operation.
258
259
 
259
- Possible enumeration values are: ImageBatchFormat.NHWC, ImageBatchFormat.NCHW.
260
+ Possible enumeration values are: ``ImageBatchFormat.NHWC``, ``ImageBatchFormat.NCHW``.
260
261
 
261
- - ImageBatchFormat.NHWC: in orders like, batch N, height H, width W, channels C to store the data.
262
- - ImageBatchFormat.NCHW: in orders like, batch N, channels C, height H, width W to store the data.
262
+ - ``ImageBatchFormat.NHWC``: in orders like, batch N, height H, width W, channels C to store the data.
263
+ - ``ImageBatchFormat.NCHW``: in orders like, batch N, channels C, height H, width W to store the data.
263
264
  """
264
265
  NHWC = 0
265
266
  NCHW = 1
@@ -279,11 +280,11 @@ class ImageReadMode(IntEnum):
279
280
  """
280
281
  The read mode used for the image file.
281
282
 
282
- Possible enumeration values are: ImageReadMode.UNCHANGED, ImageReadMode.GRAYSCALE, ImageReadMode.COLOR.
283
+ Possible enumeration values are: ``ImageReadMode.UNCHANGED``, ``ImageReadMode.GRAYSCALE``, ``ImageReadMode.COLOR``.
283
284
 
284
- - ImageReadMode.UNCHANGED: remain the output in the original format.
285
- - ImageReadMode.GRAYSCALE: convert the output into one channel grayscale data.
286
- - ImageReadMode.COLOR: convert the output into three channels RGB color data.
285
+ - ``ImageReadMode.UNCHANGED``: remain the output in the original format.
286
+ - ``ImageReadMode.GRAYSCALE``: convert the output into one channel grayscale data.
287
+ - ``ImageReadMode.COLOR``: convert the output into three channels RGB color data.
287
288
  """
288
289
  UNCHANGED = 0
289
290
  GRAYSCALE = 1
@@ -302,20 +303,19 @@ class ImageReadMode(IntEnum):
302
303
 
303
304
  class Inter(IntEnum):
304
305
  """
305
- Interpolation Modes.
306
+ Interpolation methods.
306
307
 
307
- Possible enumeration values are: Inter.NEAREST, Inter.ANTIALIAS, Inter.LINEAR, Inter.BILINEAR, Inter.CUBIC,
308
- Inter.BICUBIC, Inter.AREA, Inter.PILCUBIC.
308
+ Available values are as follows:
309
309
 
310
- - Inter.NEAREST: means interpolation method is nearest-neighbor interpolation.
311
- - Inter.ANTIALIAS: means the interpolation method is antialias interpolation.
312
- - Inter.LINEAR: means interpolation method is bilinear interpolation, here is the same as Inter.BILINEAR.
313
- - Inter.BILINEAR: means interpolation method is bilinear interpolation.
314
- - Inter.CUBIC: means the interpolation method is bicubic interpolation, here is the same as Inter.BICUBIC.
315
- - Inter.BICUBIC: means the interpolation method is bicubic interpolation.
316
- - Inter.AREA: means interpolation method is pixel area interpolation.
317
- - Inter.PILCUBIC: means interpolation method is bicubic interpolation like implemented in pillow, input
318
- should be in 3 channels format.
310
+ - ``Inter.NEAREST`` : Nearest neighbor interpolation.
311
+ - ``Inter.ANTIALIAS`` : Antialias interpolation. Supported only when the input is PIL.Image.Image.
312
+ - ``Inter.LINEAR`` : Linear interpolation, the same as ``Inter.BILINEAR``.
313
+ - ``Inter.BILINEAR`` : Bilinear interpolation.
314
+ - ``Inter.CUBIC`` : Cubic interpolation, the same as ``Inter.BICUBIC``.
315
+ - ``Inter.BICUBIC`` : Bicubic interpolation.
316
+ - ``Inter.AREA`` : Pixel area interpolation. Supported only when the input is numpy.ndarray.
317
+ - ``Inter.PILCUBIC`` : Pillow implementation of bicubic interpolation. Supported only when the input
318
+ is numpy.ndarray.
319
319
  """
320
320
  NEAREST = 0
321
321
  ANTIALIAS = 1
@@ -354,10 +354,10 @@ class SliceMode(IntEnum):
354
354
  """
355
355
  Mode to Slice Tensor into multiple parts.
356
356
 
357
- Possible enumeration values are: SliceMode.PAD, SliceMode.DROP.
357
+ Possible enumeration values are: ``SliceMode.PAD``, ``SliceMode.DROP``.
358
358
 
359
- - SliceMode.PAD: pad some pixels before slice the Tensor if needed.
360
- - SliceMode.DROP: drop remainder pixels before slice the Tensor if needed.
359
+ - ``SliceMode.PAD``: pad some pixels before slice the Tensor if needed.
360
+ - ``SliceMode.DROP``: drop remainder pixels before slice the Tensor if needed.
361
361
  """
362
362
  PAD = 0
363
363
  DROP = 1
@@ -379,7 +379,7 @@ def encode_jpeg(image, quality=75):
379
379
 
380
380
  Args:
381
381
  image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
382
- quality (int, optional): Quality of the resulting JPEG data, in range of [1, 100]. Default: 75.
382
+ quality (int, optional): Quality of the resulting JPEG data, in range of [1, 100]. Default: ``75``.
383
383
 
384
384
  Returns:
385
385
  numpy.ndarray, one dimension uint8 data.
@@ -395,6 +395,7 @@ def encode_jpeg(image, quality=75):
395
395
  ``CPU``
396
396
 
397
397
  Examples:
398
+ >>> import mindspore.dataset.vision as vision
398
399
  >>> import numpy as np
399
400
  >>> # Generate a random image with height=120, width=340, channels=3
400
401
  >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
@@ -416,7 +417,8 @@ def encode_png(image, compression_level=6):
416
417
 
417
418
  Args:
418
419
  image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
419
- compression_level (int, optional): The compression_level for encoding, in range of [0, 9]. Default: 6.
420
+ compression_level (int, optional): The `compression_level` for encoding, in range of [0, 9].
421
+ Default: ``6``.
420
422
 
421
423
  Returns:
422
424
  numpy.ndarray, one dimension uint8 data.
@@ -432,6 +434,7 @@ def encode_png(image, compression_level=6):
432
434
  ``CPU``
433
435
 
434
436
  Examples:
437
+ >>> import mindspore.dataset.vision as vision
435
438
  >>> import numpy as np
436
439
  >>> # Generate a random image with height=120, width=340, channels=3
437
440
  >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
@@ -463,6 +466,9 @@ def get_image_num_channels(image):
463
466
  TypeError: If `image` is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>.
464
467
 
465
468
  Examples:
469
+ >>> import mindspore.dataset.vision as vision
470
+ >>> from PIL import Image
471
+ >>> image = Image.open("/path/to/image_file")
466
472
  >>> num_channels = vision.get_image_num_channels(image)
467
473
  """
468
474
 
@@ -493,6 +499,9 @@ def get_image_size(image):
493
499
  TypeError: If `image` is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>.
494
500
 
495
501
  Examples:
502
+ >>> import mindspore.dataset.vision as vision
503
+ >>> from PIL import Image
504
+ >>> image = Image.open("/path/to/image_file")
496
505
  >>> image_size = vision.get_image_size(image)
497
506
  """
498
507
 
@@ -538,6 +547,7 @@ def read_file(filename):
538
547
  ``CPU``
539
548
 
540
549
  Examples:
550
+ >>> import mindspore.dataset.vision as vision
541
551
  >>> output = vision.read_file("/path/to/file")
542
552
  """
543
553
  if isinstance(filename, str):
@@ -552,8 +562,9 @@ def read_image(filename, mode=ImageReadMode.UNCHANGED):
552
562
 
553
563
  Args:
554
564
  filename(str): The path to the image file to be read.
555
- mode(ImageReadMode, optional): The mode used for decoding the image. It can be any of
556
- [ImageReadMode.UNCHANGED, ImageReadMode.GRAYSCALE, IMageReadMode.COLOR]. Default: ImageReadMode.UNCHANGED.
565
+ mode(ImageReadMode, optional): The mode used for decoding the image. It can be
566
+ ``ImageReadMode.UNCHANGED``, ``ImageReadMode.GRAYSCALE``, ``IMageReadMode.COLOR``.
567
+ Default: ``ImageReadMode.UNCHANGED``.
557
568
 
558
569
  - ImageReadMode.UNCHANGED, remain the output in the original format.
559
570
 
@@ -573,6 +584,7 @@ def read_image(filename, mode=ImageReadMode.UNCHANGED):
573
584
  ``CPU``
574
585
 
575
586
  Examples:
587
+ >>> import mindspore.dataset.vision as vision
576
588
  >>> from mindspore.dataset.vision import ImageReadMode
577
589
  >>> output = vision.read_image("/path/to/image_file", ImageReadMode.UNCHANGED)
578
590
  """
@@ -602,6 +614,7 @@ def write_file(filename, data):
602
614
  ``CPU``
603
615
 
604
616
  Examples:
617
+ >>> import mindspore.dataset.vision as vision
605
618
  >>> import numpy as np
606
619
  >>> # Generate a random data with 1024 bytes
607
620
  >>> data = np.random.randint(256, size=(1024), dtype=np.uint8)
@@ -624,7 +637,7 @@ def write_jpeg(filename, image, quality=75):
624
637
  Args:
625
638
  filename (str): The path to the file to be written.
626
639
  image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written.
627
- quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: 75.
640
+ quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: ``75``.
628
641
 
629
642
  Raises:
630
643
  TypeError: If `filename` is not of type str.
@@ -639,6 +652,7 @@ def write_jpeg(filename, image, quality=75):
639
652
  ``CPU``
640
653
 
641
654
  Examples:
655
+ >>> import mindspore.dataset.vision as vision
642
656
  >>> import numpy as np
643
657
  >>> # Generate a random image with height=120, width=340, channels=3
644
658
  >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
@@ -663,7 +677,8 @@ def write_png(filename, image, compression_level=6):
663
677
  Args:
664
678
  filename (str): The path to the file to be written.
665
679
  image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written.
666
- compression_level (int, optional): Compression level for the resulting PNG file, in range of [0, 9]. Default: 6.
680
+ compression_level (int, optional): Compression level for the resulting PNG file, in range of [0, 9].
681
+ Default: ``6``.
667
682
 
668
683
  Raises:
669
684
  TypeError: If `filename` is not of type str.
@@ -678,6 +693,7 @@ def write_png(filename, image, compression_level=6):
678
693
  ``CPU``
679
694
 
680
695
  Examples:
696
+ >>> import mindspore.dataset.vision as vision
681
697
  >>> import numpy as np
682
698
  >>> # Generate a random image with height=120, width=340, channels=3
683
699
  >>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
@@ -24,7 +24,7 @@ from mindspore.dataset.core.validator_helpers import check_value, check_uint8, F
24
24
  check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
25
25
  parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, UINT8_MIN, check_value_normalize_std, \
26
26
  check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
27
- check_pos_int32, check_int32, check_tensor_op, deprecator_factory
27
+ check_pos_int32, check_int32, check_tensor_op, deprecator_factory, check_valid_str
28
28
  from mindspore.dataset.transforms.validators import check_transform_op_type
29
29
  from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
30
30
 
@@ -339,6 +339,16 @@ def check_resize_interpolation(method):
339
339
 
340
340
  return new_method
341
341
 
342
+ def check_device_target(method):
343
+ """A wrapper that wraps a parameter checker"""
344
+
345
+ @wraps(method)
346
+ def new_method(self, *args, **kwargs):
347
+ [device_target], _ = parse_user_args(method, *args, **kwargs)
348
+ check_valid_str(device_target, ["CPU", "Ascend"], "device_target")
349
+ return method(self, *args, **kwargs)
350
+ return new_method
351
+
342
352
 
343
353
  def check_resized_crop(method):
344
354
  """A wrapper that wraps a parameter checker around the original function(ResizedCrop operation)."""
@@ -715,7 +725,7 @@ def check_pad_to_size(method):
715
725
  else:
716
726
  if len(offset) not in [0, 2]:
717
727
  raise ValueError("The offset must be empty or a sequence of length 2.")
718
- for i, value in enumerate(offset):
728
+ for i, _ in enumerate(offset):
719
729
  check_non_negative_int32(offset[i], "offset{0}".format(i))
720
730
 
721
731
  check_fill_value(fill_value)
@@ -17,6 +17,7 @@ from __future__ import absolute_import
17
17
 
18
18
  __all__ = ['MapParameter']
19
19
 
20
+ import os
20
21
  import sys
21
22
  from copy import copy
22
23
  import numbers
@@ -46,7 +47,7 @@ class MapParameter(Parameter):
46
47
  default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: 'normal'.
47
48
  permit_filter_value (numbers.Number): The permit filter value number. Default: 1.
48
49
  evict_filter_value (numbers.Number): The evict filter value number. Default: MAX_SIZE.
49
- name (str): Name of the map parameter. Default: None.
50
+ name (str): Name of the map parameter. Default: ``None``.
50
51
  requires_grad (bool): True if the parameter requires gradient. Default: True.
51
52
 
52
53
 
@@ -256,15 +257,28 @@ class MapParameter(Parameter):
256
257
 
257
258
  Args:
258
259
  incremental (bool): False for full export, otherwise for incremental export. Default: False.
259
- When exporting data incrementally, the value_array does not contain erased data, so the length of the
260
- key_array and the length of the value_array may be inconsistent.The length of the key_array and the length
261
- of the status_array are consistent.
260
+ When exporting data incrementally, the value_array does not contain unchanged data. The length
261
+ of the key_array and the length of the status_array are consistent.
262
262
 
263
263
  Returns:
264
264
  Tuple(key_array, value_array, status_array), The exported data as a tuple.
265
265
  """
266
266
  return self._map_tensor.export_data(incremental)
267
267
 
268
+ def export_bytes(self, incremental=False):
269
+ """
270
+ Export bytes from this map parameter.
271
+
272
+ Args:
273
+ incremental (bool): False for full export, otherwise for incremental export. Default: False.
274
+ When exporting data incrementally, the value_array does not contain unchanged data. The length
275
+ of the key_array and the length of the status_array are consistent.
276
+
277
+ Returns:
278
+ Tuple(bytes, bytes, bytes), The exported bytes as a tuple.
279
+ """
280
+ return self._map_tensor.export_bytes(incremental)
281
+
268
282
  def import_data(self, data):
269
283
  """
270
284
  Import this map parameter from exported data.
@@ -273,3 +287,23 @@ class MapParameter(Parameter):
273
287
  data (Tuple): The data tuple with key_array, value_array and status_array.
274
288
  """
275
289
  self._map_tensor.import_data(data)
290
+
291
+ def export_slice_data(self, incremental=False):
292
+ """
293
+ Export a slice of data from this map parameter.
294
+ When MapParameter occupies a large memory, only one slice
295
+ of MapParameter is exported at a time (the default slice size is 1GB).
296
+
297
+ Args:
298
+ incremental (bool): False for full export, otherwise for incremental export. Default: False.
299
+ When exporting data incrementally, the value_array does not contain unchanged data. The length
300
+ of the key_array and the length of the status_array are consistent.
301
+
302
+ Returns:
303
+ Tuple(key_array, value_array, status_array, last_slice), The exported data as a tuple, and
304
+ last_slice is a bool indicating whether the export has finished.
305
+ """
306
+ enable_persistent = "MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE" in os.environ
307
+ if not enable_persistent:
308
+ return self._map_tensor.export_slice_data(incremental)
309
+ return self._map_tensor.export_persistent_slice_data(self.key, incremental)
@@ -1,4 +1,4 @@
1
- # Copyright 2019 Huawei Technologies Co., Ltd
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -11,10 +11,20 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- # ==============================================================================
14
+ # ============================================================================
15
+ """
16
+ Optimizer.
15
17
 
16
- """Preprocess of dataset.
18
+ Provide common optimizers for training, such as SGD, ADAM, Momentum.
19
+ The optimizer is used to calculate and update the gradients.
17
20
  """
18
21
  from __future__ import absolute_import
19
22
 
20
- from mindspore.dataset.datapreprocess.preprocess_imagenet_validate_dataset import *
23
+ from mindspore.experimental.optim.optimizer import Optimizer
24
+ from mindspore.experimental.optim.adamw import AdamW
25
+ from mindspore.experimental.optim.sgd import SGD
26
+ from mindspore.experimental.optim.adam import Adam
27
+ from mindspore.experimental.optim import lr_scheduler
28
+
29
+
30
+ __all__ = ['Optimizer', 'AdamW', 'SGD', 'Adam']
@@ -0,0 +1,192 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """adam"""
16
+ from __future__ import absolute_import
17
+
18
+ from mindspore.ops import functional as F, composite as C, operations as P
19
+ from mindspore.common.parameter import Parameter
20
+ from mindspore.common.tensor import Tensor
21
+ import mindspore.common.dtype as mstype
22
+ from mindspore.experimental.optim.optimizer import Optimizer
23
+
24
+ _adam_opt = C.MultitypeFuncGraph("adam_opt")
25
+
26
+
27
+ @_adam_opt.register("Function", "Tensor", "Tensor", "Float", "Float", "Float", "Tensor",
28
+ "Tensor", "Tensor", "Tensor", "Tensor")
29
+ def _run_adam_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
30
+ """Apply adam optimizer to the weight parameter."""
31
+ success = True
32
+ success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient))
33
+ return success
34
+
35
+
36
+ @_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
37
+ def _run_adam_with_amsgrad_opt(opt, beta1_power, beta2_power, lr, gradient, param, moment1, moment2, vhat):
38
+ """Apply adam optimizer to the weight parameter with amsgrad."""
39
+ success = True
40
+ success = F.depend(success, opt(param, moment1, moment2, vhat, beta1_power, beta2_power, lr, gradient))
41
+ return success
42
+
43
+
44
+ class Adam(Optimizer):
45
+ r"""
46
+ Implements Adam algorithm.
47
+
48
+ The updating formulas are as follows:
49
+
50
+ .. math::
51
+ \begin{aligned}
52
+ &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
53
+ \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
54
+ &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
55
+ \:\textit{maximize} \\
56
+ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
57
+ v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
58
+ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
59
+ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
60
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
61
+ &\hspace{5mm}\textbf{else} \\
62
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
63
+ &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
64
+ &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
65
+ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
66
+ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
67
+ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
68
+ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
69
+ &\hspace{5mm}\textbf{if} \: amsgrad \\
70
+ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
71
+ \widehat{v_t}) \\
72
+ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
73
+ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
74
+ &\hspace{5mm}\textbf{else} \\
75
+ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
76
+ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
77
+ &\bf{return} \: \theta_t \\[-1.ex]
78
+ \end{aligned}
79
+
80
+ .. warning::
81
+ This is an experimental optimizer API that is subject to change.
82
+ This module must be used with lr scheduler module in `LRScheduler Class
83
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.nn.html#learningrateschedule-class>`_ .
84
+
85
+ Args:
86
+ params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
87
+ parameter groups
88
+ lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
89
+ betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
90
+ Default: ``(0.9, 0.999)``.
91
+ eps (float, optional): term added to the denominator to improve
92
+ numerical stability. Default: ``1e-8``.
93
+ weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
94
+ amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
95
+
96
+ Keyword Args:
97
+ maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
98
+ Default: ``False``.
99
+
100
+ Inputs:
101
+ - **gradients** (tuple[Tensor]) - The gradients of `params`.
102
+
103
+ Raises:
104
+ ValueError: If the `lr` is not int, float or Tensor.
105
+ ValueError: If the `lr` is less than 0.
106
+ ValueError: If the `eps` is less than 0.0.
107
+ ValueError: If the `betas` are not in the range of [0, 1).
108
+ ValueError: If the `weight_decay` is less than 0.
109
+
110
+ Supported Platforms:
111
+ ``Ascend`` ``GPU`` ``CPU``
112
+
113
+ Examples:
114
+ >>> import mindspore
115
+ >>> from mindspore import nn
116
+ >>> from mindspore.experimental import optim
117
+ >>> # Define the network structure of LeNet5. Refer to
118
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
119
+ >>> net = LeNet5()
120
+ >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
121
+ >>> optimizer = optim.Adam(net.trainable_params(), lr=0.1)
122
+ >>> def forward_fn(data, label):
123
+ ... logits = net(data)
124
+ ... loss = loss_fn(logits, label)
125
+ ... return loss, logits
126
+ >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
127
+ >>> def train_step(data, label):
128
+ ... (loss, _), grads = grad_fn(data, label)
129
+ ... optimizer(grads)
130
+ ... return loss
131
+ """
132
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
133
+ weight_decay=0, amsgrad=False, *, maximize=False):
134
+ if lr < 0.0:
135
+ raise ValueError("Invalid learning rate: {}".format(lr))
136
+ if eps < 0.0:
137
+ raise ValueError("Invalid epsilon value: {}".format(eps))
138
+ if not 0.0 <= betas[0] < 1.0:
139
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
140
+ if not 0.0 <= betas[1] < 1.0:
141
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
142
+ if weight_decay < 0.0:
143
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
144
+
145
+ defaults = dict(lr=lr, betas=betas, eps=eps,
146
+ weight_decay=weight_decay, amsgrad=amsgrad,
147
+ maximize=maximize)
148
+ super(Adam, self).__init__(params, defaults)
149
+
150
+ self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
151
+ self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
152
+ self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
153
+ self.state_step = Parameter(Tensor(0, mstype.int32), "state_step")
154
+ self.increase_tensor = Tensor(1, mstype.int32)
155
+ self.assignadd = P.AssignAdd()
156
+ self.op_add = P.AddN()
157
+ self.op_mul = P.Mul()
158
+ self.op_pow = P.Pow()
159
+ self.adam_opt = P.Adam(False, False)
160
+ self.op_cast = P.Cast()
161
+
162
+ def construct(self, gradients):
163
+ self.assignadd(self.state_step, self.increase_tensor)
164
+ for group_id, group in enumerate(self.param_groups):
165
+ start_id = self.group_start_id[group_id]
166
+ end_id = self.group_start_id[group_id+1]
167
+
168
+ lr = group.get("lr")
169
+ weight_decay = group.get("weight_decay")
170
+ beta1, beta2 = group.get("betas")
171
+ maximize = group.get("maximize")
172
+ eps = group.get("eps")
173
+
174
+ beta1_power = self.op_pow(beta1, self.state_step)
175
+ beta2_power = self.op_pow(beta2, self.state_step)
176
+ adam_with_amsgrad_opt = P.ApplyAdamWithAmsgrad(beta1, beta2, eps, False)
177
+ params = self.parameters[start_id: end_id]
178
+ grads = gradients[start_id: end_id]
179
+ grads = grads if not maximize else -grads
180
+ grads = self._decay_weight(weight_decay, params, grads)
181
+ if isinstance(lr, float):
182
+ lr = self.op_cast(group.get("lr"), mstype.float32)
183
+ if group.get("amsgrad"):
184
+ self.hyper_map(F.partial(_adam_opt, adam_with_amsgrad_opt, beta1_power, beta2_power, lr),
185
+ grads, params,
186
+ self.exp_avg[start_id: end_id], self.exp_avg_sq[start_id: end_id],
187
+ self.max_exp_avg_sq[start_id: end_id])
188
+ else:
189
+ self.hyper_map(F.partial(_adam_opt, self.adam_opt, beta1_power, beta2_power, beta1, beta2, eps, lr),
190
+ grads, params,
191
+ self.exp_avg[start_id: end_id], self.exp_avg_sq[start_id: end_id])
192
+ return True