mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package registry's release page for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -83,7 +83,7 @@ class Rprop(Optimizer):
83
83
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
84
84
  one group of `params`.
85
85
 
86
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Learning_rate. Default: 0.1.
86
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Learning_rate. Default: ``0.1`` .
87
87
 
88
88
  - float: The fixed learning rate value. Must be equal to or greater than 0.
89
89
 
@@ -98,10 +98,10 @@ class Rprop(Optimizer):
98
98
  LearningRateSchedule with step as the input to get the learning rate of current step.
99
99
 
100
100
  etas (tuple[float, float]): The factor of multiplicative increasing or
101
- descreasing(etaminus, etaplus). Default: (0.5, 1.2).
101
+ descreasing(etaminus, etaplus). Default: ``(0.5, 1.2)`` .
102
102
  step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size).
103
- Default: (1e-6, 50.).
104
- weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
103
+ Default: ``(1e-6, 50.)`` .
104
+ weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
105
105
 
106
106
  - float: The fixed weight decay value. Must be equal to or greater than 0.
107
107
 
@@ -134,7 +134,9 @@ class Rprop(Optimizer):
134
134
  >>> import mindspore as ms
135
135
  >>> from mindspore import nn
136
136
  >>>
137
- >>> net = Net()
137
+ >>> # Define the network structure of LeNet5. Refer to
138
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
139
+ >>> net = LeNet5()
138
140
  >>> #1) All parameters use the same learning rate and weight decay
139
141
  >>> optim = nn.Rprop(params=net.trainable_params())
140
142
  >>>
@@ -152,7 +154,7 @@ class Rprop(Optimizer):
152
154
  >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
153
155
  >>>
154
156
  >>> loss = nn.SoftmaxCrossEntropyWithLogits()
155
- >>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
157
+ >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)
156
158
  """
157
159
 
158
160
  @opt_init_args_register
@@ -187,8 +189,8 @@ class Rprop(Optimizer):
187
189
  self.prev = self._parameters.clone(prefix="prev", init='zeros')
188
190
  self.step_size = self._parameters.clone(prefix="step_size", init='zeros')
189
191
 
190
- self.fill = P.Fill()
191
192
  self.sign = P.Sign()
193
+ self.fill = P.FillV2()
192
194
  self.assign = P.Assign()
193
195
  self.assignadd = P.AssignAdd()
194
196
  self.cast = P.Cast()
@@ -202,8 +204,7 @@ class Rprop(Optimizer):
202
204
  gradients = self.gradients_centralization(gradients)
203
205
  gradients = self.scale_grad(gradients)
204
206
  lrs = self.get_lr()
205
- if not self._is_dynamic_lr_or_weight_decay():
206
- self.assignadd(self.global_step, self.global_step_increase_tensor)
207
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
207
208
  success = True
208
209
 
209
210
  for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self._parameters,
@@ -219,14 +220,26 @@ class Rprop(Optimizer):
219
220
  param_fp32 = self.cast(param, mstype.float32)
220
221
 
221
222
  sign = self.sign(gradient_fp32 * prev)
222
- sign = self.select(sign > 0, self.fill(mstype.float32, sign.shape, self.etaplus), sign)
223
- sign = self.select(sign < 0, self.fill(mstype.float32, sign.shape, self.etaminus), sign)
224
- sign = self.select(sign == 0, self.fill(mstype.float32, sign.shape, 1.), sign)
225
-
226
- step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign, self.step_size_min, self.step_size_max)
227
-
228
- gradient_update = self.select(sign == self.etaminus, self.fill(mstype.float32, sign.shape, 0.),
229
- gradient_fp32)
223
+ sign = self.select(
224
+ sign > 0,
225
+ self.fill(sign.shape, self.cast(self.etaplus, mstype.float32)),
226
+ sign)
227
+ sign = self.select(
228
+ sign < 0,
229
+ self.fill(sign.shape, self.cast(self.etaminus,
230
+ mstype.float32)), sign)
231
+ sign = self.select(
232
+ sign == 0, self.fill(sign.shape,
233
+ self.cast(1., mstype.float32)), sign)
234
+
235
+ step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign,
236
+ self.step_size_min,
237
+ self.step_size_max)
238
+
239
+ gradient_update = self.select(
240
+ sign == self.etaminus,
241
+ self.fill(sign.shape, self.cast(0., mstype.float32)),
242
+ gradient_fp32)
230
243
  next_param = param_fp32 - self.sign(gradient_update) * step_size_fp32
231
244
 
232
245
  self.assign(param, self.cast(next_param, param.dtype))
mindspore/nn/optim/sgd.py CHANGED
@@ -44,17 +44,17 @@ class SGD(Optimizer):
44
44
  momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
45
45
 
46
46
  .. math::
47
- v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
47
+ v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
48
48
 
49
49
  If nesterov is True:
50
50
 
51
51
  .. math::
52
- p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
52
+ p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
53
53
 
54
54
  If nesterov is False:
55
55
 
56
56
  .. math::
57
- p_{t+1} = p_{t} - lr \ast v_{t+1}
57
+ p_{t+1} = p_{t} - lr \ast v_{t+1}
58
58
 
59
59
  To be noticed, for the first step, :math:`v_{t+1} = gradient`.
60
60
 
@@ -90,7 +90,7 @@ class SGD(Optimizer):
90
90
  If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
91
91
  one group of `params`.
92
92
 
93
- learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: 0.1.
93
+ learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.1`` .
94
94
 
95
95
  - float: The fixed learning rate value. Must be equal to or greater than 0.
96
96
 
@@ -104,22 +104,22 @@ class SGD(Optimizer):
104
104
  - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
105
105
  LearningRateSchedule with step as the input to get the learning rate of current step.
106
106
 
107
- momentum (float): A floating point value the momentum. must be at least 0.0. Default: 0.0.
108
- dampening (float): A floating point value of dampening for momentum. must be at least 0.0. Default: 0.0.
109
- weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
107
+ momentum (float): A floating point value the momentum. must be at least 0.0. Default: ``0.0`` .
108
+ dampening (float): A floating point value of dampening for momentum. must be at least 0.0. Default: ``0.0`` .
109
+ weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: ``0.0`` .
110
110
  nesterov (bool): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
111
- and dampening must be equal to 0.0. Default: False.
111
+ and dampening must be equal to 0.0. Default: ``False`` .
112
112
  loss_scale (float): A floating point value for the loss scale, which must be larger than 0.0. In general, use
113
113
  the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
114
- `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
114
+ `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
115
115
  `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
116
- Default: 1.0.
116
+ Default: ``1.0`` .
117
117
 
118
118
  Inputs:
119
119
  - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
120
120
 
121
121
  Outputs:
122
- Tensor[bool], the value is True.
122
+ Tensor[bool], the value is ``True`` .
123
123
 
124
124
  Raises:
125
125
  ValueError: If the momentum, dampening or weight_decay value is less than 0.0.
@@ -131,7 +131,9 @@ class SGD(Optimizer):
131
131
  >>> import mindspore as ms
132
132
  >>> from mindspore import nn
133
133
  >>>
134
- >>> net = Net()
134
+ >>> # Define the network structure of LeNet5. Refer to
135
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
136
+ >>> net = LeNet5()
135
137
  >>> #1) All parameters use the same learning rate and weight decay
136
138
  >>> optim = nn.SGD(params=net.trainable_params())
137
139
  >>>
@@ -149,7 +151,7 @@ class SGD(Optimizer):
149
151
  >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
150
152
  >>>
151
153
  >>> loss = nn.SoftmaxCrossEntropyWithLogits()
152
- >>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
154
+ >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)
153
155
  """
154
156
 
155
157
  @opt_init_args_register
@@ -161,29 +163,29 @@ class SGD(Optimizer):
161
163
  if isinstance(momentum, int):
162
164
  momentum = float(momentum)
163
165
  if not isinstance(momentum, float):
164
- raise TypeError("For 'SGD', the argument 'momentum' must be float type, "
165
- "but got {}.".format(type(momentum)))
166
+ raise TypeError(f"For 'SGD', the argument 'momentum' must be float type, "
167
+ f"but got {type(momentum)}.")
166
168
 
167
169
  if isinstance(momentum, float) and momentum < 0.0:
168
- raise ValueError("For 'SGD', the argument 'momentum' must be at least 0.0, "
169
- "but got {}.".format(momentum))
170
+ raise ValueError(f"For 'SGD', the argument 'momentum' must be at least 0.0, "
171
+ f"but got {momentum}.")
170
172
 
171
173
  if isinstance(dampening, int):
172
174
  dampening = float(dampening)
173
175
  if not isinstance(dampening, float):
174
- raise TypeError("For 'SGD', the argument 'dampening' must be float type, "
175
- "but got {}.".format(type(dampening)))
176
+ raise TypeError(f"For 'SGD', the argument 'dampening' must be float type, "
177
+ f"but got {type(dampening)}.")
176
178
 
177
179
  if dampening < 0.0:
178
- raise ValueError("For 'SGD', the argument 'dampening' must be at least 0.0, "
179
- "but got 'dampening' {}".format(dampening))
180
+ raise ValueError(f"For 'SGD', the argument 'dampening' must be at least 0.0, "
181
+ f"but got 'dampening' {dampening}")
180
182
  self.dampening = dampening
181
183
 
182
184
  validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
183
185
 
184
186
  if nesterov and (momentum <= 0.0 or dampening != 0.0):
185
- raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
186
- "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
187
+ raise ValueError(f"For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
188
+ f"equal to 0.0, but got 'momentum' {momentum}, 'dampening' {dampening}.")
187
189
  self.nesterov = nesterov
188
190
 
189
191
  if self.dynamic_weight_decay:
@@ -196,9 +198,23 @@ class SGD(Optimizer):
196
198
  self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))
197
199
 
198
200
  self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
201
+
202
+ if not momentum > 0.0:
203
+ enable_cache_param_list = []
204
+ for param in self._parameters:
205
+ if param.cache_enable:
206
+ enable_cache_param_list.append(param)
207
+ param.cache_enable = False
208
+
199
209
  self.accum = self._parameters.clone(prefix="accum", init='zeros')
200
210
  self.stat = self._parameters.clone(prefix="stat", init='ones')
201
211
 
212
+
213
+ if not momentum > 0.0:
214
+ for param in enable_cache_param_list:
215
+ param.cache_enable = True
216
+
217
+
202
218
  @jit
203
219
  def construct(self, gradients):
204
220
  params = self._parameters
@@ -208,6 +224,7 @@ class SGD(Optimizer):
208
224
  gradients = self.gradients_centralization(gradients)
209
225
  gradients = self.scale_grad(gradients)
210
226
  lr = self.get_lr()
227
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
211
228
  if self.is_group_lr:
212
229
  success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum),
213
230
  lr, gradients, params, accum, stat, self.opt)
@@ -266,10 +266,10 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
266
266
  \otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
267
267
  \end{array}
268
268
 
269
- :math:`a_{i-1}` represents the input of i-th layer,and which is the activations of previous layer.
270
- :math:`D_{s_i}` represents the derivative of the loss function of the output of the i-th layer.
269
+ :math:`a_{i-1}` represents the input of :math:`i`-th layer,and which is the activations of previous layer.
270
+ :math:`D_{s_i}` represents the derivative of the loss function of the output of the :math:`i`-th layer.
271
271
  :math:`I` represents the identity matrix.
272
- :math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the i-th layer.
272
+ :math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the :math:`i`-th layer.
273
273
  :math:`\otimes` represents Kronecker product, :math:`\gamma` represents 'learning rate'.
274
274
 
275
275
  Note:
@@ -290,14 +290,15 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
290
290
 
291
291
  momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.
292
292
 
293
- weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
293
+ weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0.
294
+ Default: ``0.0`` .
294
295
 
295
296
  loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
296
- default value. Default: 1.0.
297
+ default value. Default: ``1.0`` .
297
298
 
298
- batch_size (int): The size of a batch. Default: 32
299
+ batch_size (int): The size of a batch. Default: ``32`` .
299
300
 
300
- use_nesterov (bool): Enable Nesterov momentum. Default: False.
301
+ use_nesterov (bool): Enable Nesterov momentum. Default: ``False`` .
301
302
 
302
303
  decay_filter (function): A function to determine which layers the weight decay applied to. And it
303
304
  only works when the weight_decay > 0. Default: lambda x: x.name not in []
@@ -305,13 +306,13 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
305
306
  split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
306
307
  computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
307
308
  to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
308
- is 27~53. Default: None
309
+ is 27~53. Default: ``None`` .
309
310
 
310
- enable_clip_grad (bool): Whether to clip the gradients. Default: False
311
+ enable_clip_grad (bool): Whether to clip the gradients. Default: ``False`` .
311
312
 
312
313
  frequency(int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
313
314
  (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
314
- and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: 100.
315
+ and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: ``100`` .
315
316
 
316
317
  Inputs:
317
318
  - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -333,21 +334,18 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
333
334
  ``Ascend`` ``GPU``
334
335
 
335
336
  Examples:
336
- .. note::
337
- Before running the following example, you need to customize the network Net and
338
- dataset preparation function create_dataset. Refer to
339
- `Building a Network <https://www.mindspore.cn/tutorials/en/r2.0/beginner/model.html>`_
340
- and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0/beginner/dataset.html>`_ .
341
-
342
337
  >>> import mindspore as ms
343
- >>> from mindspore.nn import thor
344
338
  >>> from mindspore import nn
345
339
  >>> from mindspore import Tensor
346
340
  >>>
347
- >>> net = Net()
341
+ >>> # Define the network structure of LeNet5. Refer to
342
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
343
+ >>> net = LeNet5()
344
+ >>> # Create the dataset taking MNIST as an example. Refer to
345
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
348
346
  >>> dataset = create_dataset()
349
347
  >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
350
- >>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
348
+ >>> optim = nn.thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
351
349
  >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
352
350
  >>> loss_scale = ms.FixedLossScaleManager(128, drop_overflow_update=False)
353
351
  >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
@@ -355,8 +353,6 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
355
353
  >>> model = ms.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
356
354
  ... loss_scale_manager=loss_scale, metrics={'acc'},
357
355
  ... amp_level="O2", keep_batchnorm_fp32=False)
358
- >>> loss_cb = ms.LossMonitor()
359
- >>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)
360
356
 
361
357
  """
362
358
  context.set_context(max_call_depth=10000)
@@ -428,7 +424,7 @@ class ThorGpu(Optimizer):
428
424
  self.matmul = P.MatMul()
429
425
  self.assign = P.Assign()
430
426
  self.mul = P.Mul()
431
- self.gather = P.GatherV2()
427
+ self.gather = P.Gather()
432
428
  self.one = Tensor(1, mstype.int32)
433
429
  self.feature_map = Tensor(1.0, mstype.float32)
434
430
  self.axis = 0
@@ -657,6 +653,7 @@ class ThorGpu(Optimizer):
657
653
  gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
658
654
  gradients = clip_gradient(self.enable_clip_grad, gradients)
659
655
  lr = self.get_lr()
656
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
660
657
  success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
661
658
  return success
662
659
 
@@ -681,7 +678,7 @@ class ThorAscend(Optimizer):
681
678
  self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
682
679
  logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
683
680
  self._define_ascend_operator()
684
- self.C0 = 16
681
+ self.c0 = 16
685
682
  self.device_shape_pad_flag = ()
686
683
  self.diag_block_dim = 128
687
684
  self.matrix_a = ()
@@ -743,7 +740,7 @@ class ThorAscend(Optimizer):
743
740
  self.log = P.Log()
744
741
  self.exp = P.Exp()
745
742
  self.sqrt = P.Sqrt()
746
- self.gather = P.GatherV2()
743
+ self.gather = P.Gather()
747
744
  self.assign = P.Assign()
748
745
  self.cast = P.Cast()
749
746
  self.eye = P.Eye()
@@ -989,8 +986,8 @@ class ThorAscend(Optimizer):
989
986
  kernel_hw = weight_shape[2] * weight_shape[3]
990
987
  in_channels = weight_shape[1]
991
988
  matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
992
- matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
993
- (0, self.C0 - in_channels)))(matrix_a_inv)
989
+ matrix_a_inv = P.Pad(((0, 0), (0, self.c0 - in_channels), (0, 0),
990
+ (0, self.c0 - in_channels)))(matrix_a_inv)
994
991
  return matrix_a_inv
995
992
 
996
993
  def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
@@ -1308,5 +1305,6 @@ class ThorAscend(Optimizer):
1308
1305
  gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
1309
1306
  gradients = clip_gradient(self.enable_clip_grad, gradients)
1310
1307
  lr = self.get_lr()
1308
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
1311
1309
  success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
1312
1310
  return success
@@ -16,6 +16,7 @@
16
16
  from mindspore import context
17
17
  from mindspore.nn.cell import Cell
18
18
  from mindspore.ops import operations as P
19
+ from mindspore.ops import functional as F
19
20
  from mindspore.ops.operations import _inner_ops as inner
20
21
  from mindspore.common import dtype as mstype
21
22
  from mindspore.common.tensor import Tensor
@@ -33,11 +34,11 @@ class Bijector(Cell):
33
34
  then :math:`Y = g(X)` is the random variable following the transformed distribution.
34
35
 
35
36
  Args:
36
- is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: False.
37
- is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: True.
38
- name (str): The name of the Bijector. Default: None.
39
- dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
40
- param (dict): The parameters used to initialize the Bijector. Default: None.
37
+ is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: ``False`` .
38
+ is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: ``True`` .
39
+ name (str): The name of the Bijector. Default: ``None`` .
40
+ dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: ``None`` .
41
+ param (dict): The parameters used to initialize the Bijector. Default: ``None`` .
41
42
 
42
43
  Note:
43
44
  `dtype` of bijector represents the type of the distributions that the bijector could operate on.
@@ -96,7 +97,6 @@ class Bijector(Cell):
96
97
  self.cast_base = P.Cast()
97
98
  self.dtype_base = P.DType()
98
99
  self.shape_base = P.Shape()
99
- self.fill_base = P.Fill()
100
100
  self.sametypeshape_base = inner.SameTypeShape()
101
101
  self.issubclass_base = inner.IsSubClass()
102
102
 
@@ -140,13 +140,13 @@ class Bijector(Cell):
140
140
  if self.issubclass_base(value_type, mstype.float_):
141
141
  return value
142
142
  return raise_type_error('input value of bijector', value_type, mstype.float_)
143
- dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0)
143
+ dtype_tensor = F.fill(self.dtype, self.shape_base(value), 0.0)
144
144
  self.sametypeshape_base(value, dtype_tensor)
145
145
  return value
146
146
 
147
147
  def _shape_mapping(self, shape):
148
- shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
149
- dist_shape_tensor = self.fill_base(
148
+ shape_tensor = F.fill(self.parameter_type, shape, 0.0)
149
+ dist_shape_tensor = F.fill(
150
150
  self.parameter_type, self.batch_shape, 0.0)
151
151
  return (shape_tensor + dist_shape_tensor).shape
152
152
 
@@ -165,7 +165,7 @@ class Bijector(Cell):
165
165
  self.common_dtype = None
166
166
  # cast value to a tensor if it is not None
167
167
  if isinstance(value, bool) or value is None:
168
- raise TypeError("{} cannot be type {}".format(name, type(value)))
168
+ raise TypeError(f"{name} cannot be type {type(value)}")
169
169
  value_t = Tensor(value)
170
170
  # if the bijector's dtype is not specified
171
171
  if self.dtype is None:
@@ -189,7 +189,7 @@ class Bijector(Cell):
189
189
  """
190
190
  Calculate batch_shape based on parameters.
191
191
  """
192
- if 'param_dict' not in self.parameters.keys():
192
+ if 'param_dict' not in self.parameters:
193
193
  return None
194
194
  param_dict = self.parameters.get('param_dict')
195
195
  broadcast_shape_tensor = None
@@ -25,7 +25,7 @@ class Exp(PowerTransform):
25
25
  Y = \exp(x).
26
26
 
27
27
  Args:
28
- name (str): The name of the Bijector. Default: 'Exp'.
28
+ name (str): The name of the Bijector. Default: ``'Exp'`` .
29
29
 
30
30
  Supported Platforms:
31
31
  ``Ascend`` ``GPU``
@@ -28,9 +28,9 @@ class GumbelCDF(Bijector):
28
28
  Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
29
29
 
30
30
  Args:
31
- loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.0.
32
- scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
33
- name (str): The name of the Bijector. Default: 'GumbelCDF'.
31
+ loc (float, list, numpy.ndarray, Tensor): The location. Default: ``0.0`` .
32
+ scale (float, list, numpy.ndarray, Tensor): The scale. Default: ``1.0`` .
33
+ name (str): The name of the Bijector. Default: ``'GumbelCDF'`` .
34
34
 
35
35
  Note:
36
36
  `scale` must be greater than zero.
@@ -25,7 +25,7 @@ class Invert(Bijector):
25
25
 
26
26
  Args:
27
27
  bijector (Bijector): Base Bijector.
28
- name (str): The name of the Bijector. Default: "". When name is set to "", it is actually
28
+ name (str): The name of the Bijector. Default: ``""`` . When name is set to "", it is actually
29
29
  'Invert' + bijector.name.
30
30
 
31
31
  Supported Platforms:
@@ -14,6 +14,7 @@
14
14
  # ============================================================================
15
15
  """PowerTransform Bijector"""
16
16
  from mindspore.ops import operations as P
17
+ from mindspore.ops import functional as F
17
18
  from ..distribution._utils.utils import check_greater_equal_zero
18
19
  from ..distribution._utils.custom_ops import exp_generic, log_generic
19
20
  from .bijector import Bijector
@@ -34,8 +35,8 @@ class PowerTransform(Bijector):
34
35
  This Bijector is equivalent to the :class:`mindspore.nn.probability.bijector.Exp` bijector when `c=0`.
35
36
 
36
37
  Args:
37
- power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0.
38
- name (str): The name of the bijector. Default: 'PowerTransform'.
38
+ power (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``0`` .
39
+ name (str): The name of the bijector. Default: ``'PowerTransform'`` .
39
40
 
40
41
  Note:
41
42
  The dtype of `power` must be float.
@@ -68,10 +69,7 @@ class PowerTransform(Bijector):
68
69
  >>> print(ans4.shape)
69
70
  (3,)
70
71
  """
71
-
72
- def __init__(self,
73
- power=0.,
74
- name='PowerTransform'):
72
+ def __init__(self, power=0., name='PowerTransform'):
75
73
  param = dict(locals())
76
74
  param['param_dict'] = {'power': power}
77
75
  super(PowerTransform, self).__init__(name=name, param=param)
@@ -84,7 +82,6 @@ class PowerTransform(Bijector):
84
82
  self.equal_base = P.Equal()
85
83
  self.exp = exp_generic
86
84
  self.expm1 = P.Expm1()
87
- self.fill = P.Fill()
88
85
  self.log = log_generic
89
86
  self.log1p = P.Log1p()
90
87
  self.select_base = P.Select()
@@ -116,17 +113,18 @@ class PowerTransform(Bijector):
116
113
  power_local = self.cast_param_by_value(x, self.power)
117
114
 
118
115
  # broad cast the value of x and power
119
- ones = self.fill(self.dtypeop(power_local),
120
- self.shape(x + power_local), 1.)
116
+ ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
117
+ 1.)
121
118
  power_local = power_local * ones
122
119
  x = x * ones
123
- safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
124
- ones,
125
- power_local)
126
-
127
- forward_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
128
- self.exp(x),
129
- self.exp(self.log1p(x * safe_power) / safe_power))
120
+ safe_power = self.select_base(
121
+ self.equal_base(power_local,
122
+ P.ZerosLike()(power_local)), ones, power_local)
123
+
124
+ forward_v = self.select_base(
125
+ self.equal_base(power_local,
126
+ P.ZerosLike()(power_local)), self.exp(x),
127
+ self.exp(self.log1p(x * safe_power) / safe_power))
130
128
  return forward_v
131
129
 
132
130
  def _inverse(self, y):
@@ -137,17 +135,18 @@ class PowerTransform(Bijector):
137
135
  power_local = self.cast_param_by_value(y, self.power)
138
136
 
139
137
  # broad cast the value of x and power
140
- ones = self.fill(self.dtypeop(power_local),
141
- self.shape(y + power_local), 1.)
138
+ ones = F.fill(self.dtypeop(power_local), self.shape(y + power_local),
139
+ 1.)
142
140
  power_local = power_local * ones
143
141
  y = y * ones
144
- safe_power = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
145
- ones,
146
- power_local)
142
+ safe_power = self.select_base(
143
+ self.equal_base(power_local,
144
+ P.ZerosLike()(power_local)), ones, power_local)
147
145
 
148
- inverse_v = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
149
- self.log(y),
150
- self.expm1(self.log(y) * safe_power) / safe_power)
146
+ inverse_v = self.select_base(
147
+ self.equal_base(power_local,
148
+ P.ZerosLike()(power_local)), self.log(y),
149
+ self.expm1(self.log(y) * safe_power) / safe_power)
151
150
 
152
151
  return inverse_v
153
152
 
@@ -167,14 +166,15 @@ class PowerTransform(Bijector):
167
166
  power_local = self.cast_param_by_value(x, self.power)
168
167
 
169
168
  # broad cast the value of x and power
170
- ones = self.fill(self.dtypeop(power_local),
171
- self.shape(x + power_local), 1.)
169
+ ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
170
+ 1.)
172
171
  power_local = power_local * ones
173
172
  x = x * ones
174
173
 
175
- forward_log_j = self.select_base(self.equal_base(power_local, P.ZerosLike()(power_local)),
176
- x,
177
- (1. / power_local - 1) * self.log1p(x * power_local))
174
+ forward_log_j = self.select_base(
175
+ self.equal_base(power_local,
176
+ P.ZerosLike()(power_local)), x,
177
+ (1. / power_local - 1) * self.log1p(x * power_local))
178
178
 
179
179
  return forward_log_j
180
180
 
@@ -29,9 +29,9 @@ class ScalarAffine(Bijector):
29
29
  where a is the scale factor and b is the shift factor.
30
30
 
31
31
  Args:
32
- scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
33
- shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0.
34
- name (str): The name of the bijector. Default: 'ScalarAffine'.
32
+ scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``1.0`` .
33
+ shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: ``0.0`` .
34
+ name (str): The name of the bijector. Default: ``'ScalarAffine'`` .
35
35
 
36
36
  Note:
37
37
  The dtype of `shift` and `scale` must be float.