mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's advisory page for more details.

Files changed (884)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@
15
15
  """Softplus Bijector"""
16
16
  import numpy as np
17
17
  from mindspore.ops import operations as P
18
+ from mindspore.ops import functional as F
18
19
  from mindspore.nn.layer.activation import LogSigmoid
19
20
  from ..distribution._utils.custom_ops import exp_generic, log_generic
20
21
  from .bijector import Bijector
@@ -31,8 +32,8 @@ class Softplus(Bijector):
31
32
  where k is the sharpness factor.
32
33
 
33
34
  Args:
34
- sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
35
- name (str): The name of the Bijector. Default: 'Softplus'.
35
+ sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``1.0`` .
36
+ name (str): The name of the Bijector. Default: ``'Softplus'`` .
36
37
 
37
38
  Note:
38
39
  The dtype of `sharpness` must be float.
@@ -84,7 +85,6 @@ class Softplus(Bijector):
84
85
  self.abs = P.Abs()
85
86
  self.dtypeop = P.DType()
86
87
  self.cast = P.Cast()
87
- self.fill = P.Fill()
88
88
  self.greater = P.Greater()
89
89
  self.less = P.Less()
90
90
  self.log_sigmoid = LogSigmoid()
@@ -103,7 +103,7 @@ class Softplus(Bijector):
103
103
  too_large = self.greater(x, -self.threshold)
104
104
  too_small_value = self.exp(x)
105
105
  too_large_value = x
106
- ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
106
+ ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
107
107
  too_small_or_too_large = self.logicalor(too_small, too_large)
108
108
  x = self.select(too_small_or_too_large, ones, x)
109
109
  y = self.log(self.exp(x) + 1.0)
@@ -119,7 +119,7 @@ class Softplus(Bijector):
119
119
  too_large = self.greater(x, (-1) * self.threshold)
120
120
  too_small_value = self.log(x)
121
121
  too_large_value = x
122
- ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
122
+ ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
123
123
  too_small_or_too_large = self.logicalor(too_small, too_large)
124
124
  x = self.select(too_small_or_too_large, ones, x)
125
125
  y = x + self.log(self.abs(self.expm1((-1)*x)))
@@ -27,8 +27,10 @@ class WithBNNLossCell(Cell):
27
27
  Args:
28
28
  backbone (Cell): The target network.
29
29
  loss_fn (Cell): The loss function used to compute loss.
30
- dnn_factor(int, float): The coefficient of backbone's loss, which is computed by the loss function. Default: 1.
31
- bnn_factor(int, float): The coefficient of KL loss, which is the KL divergence of Bayesian layer. Default: 1.
30
+ dnn_factor(int, float): The coefficient of backbone's loss, which is computed by the loss function.
31
+ Default: ``1`` .
32
+ bnn_factor(int, float): The coefficient of KL loss, which is the KL divergence of Bayesian layer.
33
+ Default: ``1`` .
32
34
 
33
35
  Inputs:
34
36
  - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
@@ -157,11 +157,11 @@ class ConvReparam(_ConvVariational):
157
157
  stride(Union[int, tuple[int]]): The distance of kernel moving,
158
158
  an integer number represents that the height and width of movement
159
159
  are both strides, or a tuple of two integers numbers represents that
160
- height and width of movement respectively. Default: 1.
160
+ height and width of movement respectively. Default: ``1`` .
161
161
  pad_mode (str): Specifies the padding mode. The optional values are
162
- "same", "valid", and "pad". Default: "same".
162
+ ``"same"`` , ``"valid"`` , and ``"pad"`` . Default: ``"same"`` .
163
163
 
164
- - same: Adopts the way of completion. Output height and width
164
+ - ``"same"``: Adopts the way of completion. Output height and width
165
165
  will be the same as the input.
166
166
  The total number of padding will be calculated for in horizontal and
167
167
  vertical directions and evenly distributed to top and bottom,
@@ -169,43 +169,43 @@ class ConvReparam(_ConvVariational):
169
169
  will be done from the bottom and the right side. If this mode
170
170
  is set, `padding` must be 0.
171
171
 
172
- - valid: Adopts the way of discarding. The possible largest
172
+ - ``"valid"``: Adopts the way of discarding. The possible largest
173
173
  height and width of the output will be returned without padding.
174
174
  Extra pixels will be discarded. If this mode is set, `padding`
175
175
  must be 0.
176
176
 
177
- - pad: Implicit paddings on both sides of the input. The number
177
+ - ``"pad"``: Implicit paddings on both sides of the input. The number
178
178
  of `padding` will be padded to the input Tensor borders.
179
179
  `padding` must be greater than or equal to 0.
180
180
 
181
181
  padding (Union[int, tuple[int]]): Implicit paddings on both sides of
182
- the input. Default: 0.
182
+ the input. Default: ``0`` .
183
183
  dilation (Union[int, tuple[int]]): The data type is an integer or a tuple
184
184
  of 2 integers. This parameter specifies the dilation rate of the
185
185
  dilated convolution. If set to be :math:`k > 1`,
186
186
  there will be :math:`k - 1` pixels skipped for each sampling
187
187
  location. Its value must be greater or equal to 1 and bounded
188
- by the height and width of the input. Default: 1.
188
+ by the height and width of the input. Default: ``1`` .
189
189
  group (int): Splits filter into groups, `in_ channels` and
190
190
  `out_channels` must be divisible by the number of groups.
191
- Default: 1.
191
+ Default: ``1`` .
192
192
  has_bias (bool): Specifies whether the layer uses a bias vector.
193
- Default: False.
193
+ Default: ``False`` .
194
194
  weight_prior_fn (Cell): The prior distribution for weight.
195
195
  It must return a mindspore distribution instance.
196
- Default: NormalPrior. (which creates an instance of standard
196
+ Default: ``NormalPrior`` . (which creates an instance of standard
197
197
  normal distribution). The current version only supports normal distribution.
198
198
  weight_posterior_fn (function): The posterior distribution for sampling weight.
199
199
  It must be a function handle which returns a mindspore
200
- distribution instance. Default: normal_post_fn.
200
+ distribution instance. Default: ``normal_post_fn`` .
201
201
  The current version only supports normal distribution.
202
202
  bias_prior_fn (Cell): The prior distribution for bias vector. It must return
203
- a mindspore distribution. Default: NormalPrior(which creates an
203
+ a mindspore distribution. Default: ``NormalPrior`` (which creates an
204
204
  instance of standard normal distribution). The current version
205
205
  only supports normal distribution.
206
206
  bias_posterior_fn (function): The posterior distribution for sampling bias vector.
207
207
  It must be a function handle which returns a mindspore
208
- distribution instance. Default: normal_post_fn.
208
+ distribution instance. Default: ``normal_post_fn`` .
209
209
  The current version only supports normal distribution.
210
210
 
211
211
  Inputs:
@@ -136,23 +136,23 @@ class DenseReparam(_DenseVariational):
136
136
  activation (str, Cell): A regularization function applied to the output of the layer.
137
137
  The type of `activation` can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()).
138
138
  Note that if the type of activation is Cell, it must be instantiated beforehand.
139
- Default: None.
140
- has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
139
+ Default: ``None`` .
140
+ has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``False`` .
141
141
  weight_prior_fn (Cell): The prior distribution for weight.
142
142
  It must return a mindspore distribution instance.
143
- Default: NormalPrior. (which creates an instance of standard
143
+ Default: ``NormalPrior`` . (which creates an instance of standard
144
144
  normal distribution). The current version only supports normal distribution.
145
145
  weight_posterior_fn (function): The posterior distribution for sampling weight.
146
146
  It must be a function handle which returns a mindspore
147
- distribution instance. Default: normal_post_fn.
147
+ distribution instance. Default: ``normal_post_fn`` .
148
148
  The current version only supports normal distribution.
149
149
  bias_prior_fn (Cell): The prior distribution for bias vector. It must return
150
- a mindspore distribution. Default: NormalPrior(which creates an
150
+ a mindspore distribution. Default: ``NormalPrior`` (which creates an
151
151
  instance of standard normal distribution). The current version
152
152
  only supports normal distribution.
153
153
  bias_posterior_fn (function): The posterior distribution for sampling bias vector.
154
154
  It must be a function handle which returns a mindspore
155
- distribution instance. Default: normal_post_fn.
155
+ distribution instance. Default: ``normal_post_fn`` .
156
156
  The current version only supports normal distribution.
157
157
 
158
158
  Inputs:
@@ -230,23 +230,23 @@ class DenseLocalReparam(_DenseVariational):
230
230
  activation (str, Cell): A regularization function applied to the output of the layer.
231
231
  The type of `activation` can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()).
232
232
  Note that if the type of activation is Cell, it must be instantiated beforehand.
233
- Default: None.
234
- has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
233
+ Default: ``None`` .
234
+ has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``False`` .
235
235
  weight_prior_fn (Cell): The prior distribution for weight.
236
236
  It must return a mindspore distribution instance.
237
- Default: NormalPrior. (which creates an instance of standard
237
+ Default: ``NormalPrior`` . (which creates an instance of standard
238
238
  normal distribution). The current version only supports normal distribution.
239
239
  weight_posterior_fn (function): The posterior distribution for sampling weight.
240
240
  It must be a function handle which returns a mindspore
241
- distribution instance. Default: normal_post_fn.
241
+ distribution instance. Default: ``normal_post_fn`` .
242
242
  The current version only supports normal distribution.
243
243
  bias_prior_fn (Cell): The prior distribution for bias vector. It must return
244
- a mindspore distribution. Default: NormalPrior(which creates an
244
+ a mindspore distribution. Default: ``NormalPrior`` (which creates an
245
245
  instance of standard normal distribution). The current version
246
246
  only supports normal distribution.
247
247
  bias_posterior_fn (function): The posterior distribution for sampling bias vector.
248
248
  It must be a function handle which returns a mindspore
249
- distribution instance. Default: normal_post_fn.
249
+ distribution instance. Default: ``normal_post_fn`` .
250
250
  The current version only supports normal distribution.
251
251
 
252
252
  Inputs:
@@ -30,9 +30,9 @@ class NormalPrior(Cell):
30
30
 
31
31
  Args:
32
32
  dtype (mindspore.dtype): The argument is used to define the data type of the output tensor.
33
- Default: mindspore.float32.
34
- mean (int, float): Mean of normal distribution. Default: 0.
35
- std (int, float): Standard deviation of normal distribution. Default: 0.1.
33
+ Default: ``mindspore.float32`` .
34
+ mean (int, float): Mean of normal distribution. Default: ``0`` .
35
+ std (int, float): Standard deviation of normal distribution. Default: ``0.1`` .
36
36
 
37
37
  Returns:
38
38
  Cell, a normal distribution.
@@ -56,12 +56,13 @@ class NormalPosterior(Cell):
56
56
  name (str): Name prepended to trainable parameter.
57
57
  shape (list, tuple): Shape of the mean and standard deviation.
58
58
  dtype (mindspore.dtype): The argument is used to define the data type of the output tensor.
59
- Default: mindspore.float32.
60
- loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
61
- loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
62
- untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters. Default: -5.
59
+ Default: ``mindspore.float32`` .
60
+ loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: ``0`` .
61
+ loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: ``0.1`` .
62
+ untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters.
63
+ Default: ``-5`` .
63
64
  untransformed_scale_std (int, float): Standard deviation of distribution to initialize trainable parameters.
64
- Default: 0.1.
65
+ Default: ``0.1`` .
65
66
 
66
67
  Returns:
67
68
  Cell, a normal distribution.
@@ -15,8 +15,17 @@
15
15
  """Utility functions to help distribution class."""
16
16
  import numpy as np
17
17
  from mindspore.ops import operations as P
18
+ from mindspore.ops import functional as F
18
19
  from mindspore.ops.operations import _inner_ops as inner
20
+ from mindspore.ops.primitive import constexpr
19
21
  from mindspore.common import dtype as mstype
22
+ from .utils import CheckTensor
23
+
24
+
25
+ @constexpr(check=False)
26
+ def _check_tensor(x, name):
27
+ CheckTensor()(x, name)
28
+ return x
20
29
 
21
30
 
22
31
  def exp_generic(input_x):
@@ -44,7 +53,6 @@ def log_generic(input_x):
44
53
  log = P.Log()
45
54
  less = P.Less()
46
55
  lessequal = P.LessEqual()
47
- fill = P.Fill()
48
56
  cast = P.Cast()
49
57
  dtype = P.DType()
50
58
  shape = P.Shape()
@@ -53,8 +61,8 @@ def log_generic(input_x):
53
61
 
54
62
  if not checktype(dtype(input_x), mstype.float_):
55
63
  input_x = cast(input_x, mstype.float32)
56
- nan = fill(dtype(input_x), shape(input_x), np.nan)
57
- inf = fill(dtype(input_x), shape(input_x), np.inf)
64
+ nan = F.fill(dtype(input_x), shape(input_x), np.nan)
65
+ inf = F.fill(dtype(input_x), shape(input_x), np.inf)
58
66
  neg_x = less(input_x, 0.0)
59
67
  nonpos_x = lessequal(input_x, 0.0)
60
68
  log_x = log(input_x)
@@ -63,6 +71,14 @@ def log_generic(input_x):
63
71
  return select(neg_x, nan, result)
64
72
 
65
73
 
74
+ def log_generic_with_check(x):
75
+ """
76
+ log generic with input check
77
+ """
78
+ _check_tensor(x, "the input of log_generic")
79
+ return log_generic(x)
80
+
81
+
66
82
  def log1p_generic(x):
67
83
  """
68
84
  Log1p ops on GPU device or when device_target == GPU.
@@ -315,7 +315,7 @@ class CheckTensor(PrimitiveWithInfer):
315
315
  def __infer__(self, x, name):
316
316
  src_type = x['dtype']
317
317
  validator.check_subclass(
318
- "input", src_type, [mstype.tensor], name["value"])
318
+ "input", src_type, [mstype.tensor_type], name["value"])
319
319
 
320
320
  out = {'shape': None,
321
321
  'dtype': None,
@@ -15,6 +15,7 @@
15
15
  """Bernoulli Distribution"""
16
16
  from mindspore.common import dtype as mstype
17
17
  from mindspore.ops import operations as P
18
+ from mindspore.ops import functional as F
18
19
  from mindspore.ops import composite as C
19
20
  from mindspore import _checkparam as Validator
20
21
  from .distribution import Distribution
@@ -29,10 +30,10 @@ class Bernoulli(Distribution):
29
30
  and the probability mass function as :math:`P(X = 0) = p, P(X = 1) = 1-p`.
30
31
 
31
32
  Args:
32
- probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: None.
33
- seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
34
- dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
35
- name (str): The name of the distribution. Default: 'Bernoulli'.
33
+ probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: ``None`` .
34
+ seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
35
+ dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` .
36
+ name (str): The name of the distribution. Default: ``'Bernoulli'`` .
36
37
 
37
38
  Note:
38
39
  `probs` must be a proper probability (0 < p < 1).
@@ -151,7 +152,6 @@ class Bernoulli(Distribution):
151
152
  self.cast = P.Cast()
152
153
  self.const = P.ScalarToTensor()
153
154
  self.floor = P.Floor()
154
- self.fill = P.Fill()
155
155
  self.less = P.Less()
156
156
  self.shape = P.Shape()
157
157
  self.select = P.Select()
@@ -200,8 +200,8 @@ class Bernoulli(Distribution):
200
200
  MODE(B) = 1 if probs1 > 0.5 else = 0
201
201
  """
202
202
  probs1 = self._check_param_type(probs1)
203
- zeros = self.fill(self.dtype, self.shape(probs1), 0.0)
204
- ones = self.fill(self.dtype, self.shape(probs1), 1.0)
203
+ zeros = F.fill(self.dtype, self.shape(probs1), 0.0)
204
+ ones = F.fill(self.dtype, self.shape(probs1), 1.0)
205
205
  comp = self.less(0.5, probs1)
206
206
  return self.select(comp, ones, zeros)
207
207
 
@@ -278,9 +278,9 @@ class Bernoulli(Distribution):
278
278
  probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
279
279
  comp_zero = self.less(value, 0.0)
280
280
  comp_one = self.less(value, 1.0)
281
- zeros = self.fill(self.parameter_type, self.shape(
281
+ zeros = F.fill(self.parameter_type, self.shape(
282
282
  broadcast_shape_tensor), 0.0)
283
- ones = self.fill(self.parameter_type, self.shape(
283
+ ones = F.fill(self.parameter_type, self.shape(
284
284
  broadcast_shape_tensor), 1.0)
285
285
  less_than_zero = self.select(comp_zero, zeros, probs0)
286
286
  return self.select(comp_one, less_than_zero, ones)
@@ -15,6 +15,7 @@
15
15
  """Beta Distribution"""
16
16
  import numpy as np
17
17
  from mindspore.ops import operations as P
18
+ from mindspore.ops import functional as F
18
19
  from mindspore.ops import composite as C
19
20
  import mindspore.nn as nn
20
21
  from mindspore import _checkparam as Validator
@@ -36,12 +37,12 @@ class Beta(Distribution):
36
37
 
37
38
  Args:
38
39
  concentration1 (int, float, list, numpy.ndarray, Tensor): The concentration1,
39
- also know as alpha of the Beta distribution. Default: None.
40
+ also know as alpha of the Beta distribution. Default: ``None`` .
40
41
  concentration0 (int, float, list, numpy.ndarray, Tensor): The concentration0, also know as
41
- beta of the Beta distribution. Default: None.
42
- seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
43
- dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
44
- name (str): The name of the distribution. Default: 'Beta'.
42
+ beta of the Beta distribution. Default: ``None`` .
43
+ seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
44
+ dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
45
+ name (str): The name of the distribution. Default: ``'Beta'`` .
45
46
 
46
47
  Note:
47
48
  - `concentration1` and `concentration0` must be greater than zero.
@@ -186,7 +187,6 @@ class Beta(Distribution):
186
187
  self.pow = P.Pow()
187
188
  self.squeeze = P.Squeeze(0)
188
189
  self.cast = P.Cast()
189
- self.fill = P.Fill()
190
190
  self.shape = P.Shape()
191
191
  self.select = P.Select()
192
192
  self.logicaland = P.LogicalAnd()
@@ -266,7 +266,7 @@ class Beta(Distribution):
266
266
  comp2 = self.greater(concentration0, 1.)
267
267
  cond = self.logicaland(comp1, comp2)
268
268
  batch_shape = self.shape(concentration1 + concentration0)
269
- nan = self.fill(self.dtype, batch_shape, np.nan)
269
+ nan = F.fill(self.dtype, batch_shape, np.nan)
270
270
  mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
271
271
  return self.select(cond, mode, nan)
272
272
 
@@ -379,7 +379,7 @@ class Beta(Distribution):
379
379
  sample_shape = (1,)
380
380
  else:
381
381
  sample_shape = origin_shape
382
- ones = self.fill(self.dtype, sample_shape, 1.0)
382
+ ones = F.fill(self.dtype, sample_shape, 1.0)
383
383
  sample_gamma1 = C.gamma(
384
384
  sample_shape, alpha=concentration1, beta=ones, seed=self.seed)
385
385
  sample_gamma2 = C.gamma(
@@ -15,7 +15,9 @@
15
15
  """Categorical Distribution"""
16
16
  import numpy as np
17
17
  from mindspore import context
18
+ from mindspore.common import Tensor
18
19
  from mindspore.ops import operations as P
20
+ from mindspore.ops import functional as F
19
21
  from mindspore.ops import composite as C
20
22
  from mindspore.ops.functional import stop_gradient
21
23
  from mindspore.ops.operations import _inner_ops as inner
@@ -26,7 +28,7 @@ from mindspore.common import dtype as mstype
26
28
  from .distribution import Distribution
27
29
  from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
28
30
  check_distribution_name
29
- from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
31
+ from ._utils.custom_ops import exp_generic, log_generic, broadcast_to, log_generic_with_check
30
32
 
31
33
 
32
34
  class Categorical(Distribution):
@@ -36,10 +38,10 @@ class Categorical(Distribution):
36
38
  and the probability mass function as :math:`P(X = i) = p_i, i = 1, ..., k`.
37
39
 
38
40
  Args:
39
- probs (Tensor, list, numpy.ndarray): Event probabilities. Default: None.
40
- seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
41
- dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
42
- name (str): The name of the distribution. Default: Categorical.
41
+ probs (Tensor, list, numpy.ndarray): Event probabilities. Default: ``None`` .
42
+ seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: ``None`` .
43
+ dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` .
44
+ name (str): The name of the distribution. Default: ``Categorical`` .
43
45
 
44
46
  Note:
45
47
  `probs` must have rank at least 1, values are proper probabilities and sum to 1.
@@ -148,7 +150,6 @@ class Categorical(Distribution):
148
150
  self.dtypeop = P.DType()
149
151
  self.exp = exp_generic
150
152
  self.expand_dim = P.ExpandDims()
151
- self.fill = P.Fill()
152
153
  self.gather = P.GatherNd()
153
154
  self.greater = P.Greater()
154
155
  self.issubclass = inner.IsSubClass()
@@ -156,6 +157,7 @@ class Categorical(Distribution):
156
157
  # when the graph kernel mode is enable
157
158
  # use Log directly as akg will handle the corner cases
158
159
  self.log = P.Log() if context.get_context("enable_graph_kernel") else log_generic
160
+ self.log_with_check = P.Log() if context.get_context("enable_graph_kernel") else log_generic_with_check
159
161
  self.log_softmax = P.LogSoftmax()
160
162
  self.logicor = P.LogicalOr()
161
163
  self.logicand = P.LogicalAnd()
@@ -253,8 +255,11 @@ class Categorical(Distribution):
253
255
  probs_b = self._check_value(probs_b, 'probs_b')
254
256
  probs_b = self.cast(probs_b, self.parameter_type)
255
257
  probs_a = self._check_param_type(probs)
256
- logits_a = self.log(probs_a)
257
- logits_b = self.log(probs_b)
258
+ if probs is None:
259
+ logits_a = self.log(probs_a)
260
+ else:
261
+ logits_a = self.log_with_check(probs_a)
262
+ logits_b = self.log_with_check(probs_b)
258
263
  return self.squeeze(self.reduce_sum(
259
264
  self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1))
260
265
 
@@ -287,7 +292,7 @@ class Categorical(Distribution):
287
292
  # here we simulate casting to int but still keeping float dtype
288
293
  value = self.cast(value, self.dtypeop(probs))
289
294
 
290
- zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
295
+ zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
291
296
  between_zero_neone = self.logicand(self.less(value, 0,),
292
297
  self.greater(value, -1.))
293
298
  value = self.select(between_zero_neone,
@@ -323,15 +328,18 @@ class Categorical(Distribution):
323
328
  value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
324
329
  value_clipped = self.cast(value_clipped, self.index_type)
325
330
  # create index from 0 ... NumOfLabels
326
- index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
331
+ start = Tensor(0, self.index_type)
332
+ end = self.cast(self.shape(value)[0], self.index_type)
333
+ delta = Tensor(1, self.index_type)
334
+ index = self.reshape(ops.range(start, end, delta), (-1, 1))
327
335
  index = self.concat((index, value_clipped))
328
336
 
329
337
  # index into logit_pmf, fill in out_of_bound places with -inf
330
338
  # reshape into label shape N
331
339
  logits_pmf = self.gather(self.reshape(
332
340
  logits, (-1, num_classes)), index)
333
- nan = self.fill(self.dtypeop(logits_pmf),
334
- self.shape(logits_pmf), self.nan)
341
+ nan = F.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf),
342
+ self.nan)
335
343
  logits_pmf = self.select(out_of_bound, nan, logits_pmf)
336
344
  ans = self.reshape(logits_pmf, label_shape)
337
345
  if drop_dim:
@@ -351,7 +359,7 @@ class Categorical(Distribution):
351
359
 
352
360
  value = self.cast(value, self.dtypeop(probs))
353
361
 
354
- zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
362
+ zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
355
363
  between_zero_neone = self.logicand(
356
364
  self.less(value, 0,), self.greater(value, -1.))
357
365
  value = self.select(between_zero_neone, zeros, P.Floor()(value))
@@ -386,7 +394,7 @@ class Categorical(Distribution):
386
394
  # reshape probs and fill less_than_zero places with 0
387
395
  probs = self.reshape(probs, (-1, num_classes))
388
396
  cdf = self.gather(self.cumsum(probs, 1), index)
389
- zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
397
+ zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
390
398
  cdf = self.select(less_than_zero, zeros, cdf)
391
399
  cdf = self.reshape(cdf, label_shape)
392
400
 
@@ -417,7 +425,7 @@ class Categorical(Distribution):
417
425
  sample_shape = (1,)
418
426
 
419
427
  probs_2d = self.reshape(probs, (-1, num_classes))
420
- sample_tensor = self.fill(self.dtype, shape, 1.0)
428
+ sample_tensor = F.fill(self.dtype, shape, 1.0)
421
429
  sample_tensor = self.reshape(sample_tensor, (-1, 1))
422
430
  num_sample = self.shape(sample_tensor)[0]
423
431
  samples = C.multinomial(probs_2d, num_sample, seed=self.seed)
@@ -35,11 +35,11 @@ class Cauchy(Distribution):
35
35
  where :math:`a, b` are loc and scale parameter respectively.
36
36
 
37
37
  Args:
38
- loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: None.
39
- scale (int, float, list, numpy.ndarray, Tensor): The scale of the Cauchy distribution. Default: None.
40
- seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
41
- dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
42
- name (str): The name of the distribution. Default: 'Cauchy'.
38
+ loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: ``None`` .
39
+ scale (int, float, list, numpy.ndarray, Tensor): The scale of the Cauchy distribution. Default: ``None`` .
40
+ seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
41
+ dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
42
+ name (str): The name of the distribution. Default: ``'Cauchy'`` .
43
43
 
44
44
  Note:
45
45
  `scale` must be greater than zero.
@@ -170,7 +170,6 @@ class Cauchy(Distribution):
170
170
  self.const = P.ScalarToTensor()
171
171
  self.dtypeop = P.DType()
172
172
  self.exp = exp_generic
173
- self.fill = P.Fill()
174
173
  self.less = P.Less()
175
174
  self.log = log_generic
176
175
  self.log1p = log1p_generic
@@ -15,6 +15,7 @@
15
15
  """basic"""
16
16
  from mindspore import context
17
17
  from mindspore.ops import operations as P
18
+ from mindspore.ops import functional as F
18
19
  from mindspore.nn.cell import Cell
19
20
  from mindspore.ops.primitive import constexpr
20
21
  from mindspore.ops.operations import _inner_ops as inner
@@ -113,7 +114,6 @@ class Distribution(Cell):
113
114
  # ops needed for the base class
114
115
  self.cast_base = P.Cast()
115
116
  self.dtype_base = P.DType()
116
- self.fill_base = P.Fill()
117
117
  self.sametypeshape_base = inner.SameTypeShape()
118
118
  self.sq_base = P.Square()
119
119
  self.sqrt_base = P.Sqrt()
@@ -194,11 +194,11 @@ class Distribution(Cell):
194
194
  if broadcast_shape is None:
195
195
  broadcast_shape = self.shape_base(arg)
196
196
  common_dtype = self.dtype_base(arg)
197
- broadcast_shape_tensor = self.fill_base(
197
+ broadcast_shape_tensor = F.fill(
198
198
  common_dtype, broadcast_shape, 1.0)
199
199
  else:
200
200
  broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
201
- broadcast_shape_tensor = self.fill_base(
201
+ broadcast_shape_tensor = F.fill(
202
202
  common_dtype, broadcast_shape, 1.0)
203
203
  arg = self.broadcast(arg, broadcast_shape_tensor)
204
204
  # check if the arguments have the same dtype
@@ -35,10 +35,10 @@ class Exponential(Distribution):
35
35
  where :math:`\lambda` is the rate of the distribution.
36
36
 
37
37
  Args:
38
- rate (int, float, list, numpy.ndarray, Tensor): The inverse scale. Default: None.
39
- seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
40
- dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
41
- name (str): The name of the distribution. Default: 'Exponential'.
38
+ rate (int, float, list, numpy.ndarray, Tensor): The inverse scale. Default: ``None`` .
39
+ seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
40
+ dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
41
+ name (str): The name of the distribution. Default: ``'Exponential'`` .
42
42
 
43
43
  Note:
44
44
  `rate` must be strictly greater than 0.
@@ -15,6 +15,7 @@
15
15
  """Gamma Distribution"""
16
16
  import numpy as np
17
17
  from mindspore.ops import operations as P
18
+ from mindspore.ops import functional as F
18
19
  from mindspore.ops import composite as C
19
20
  import mindspore.nn as nn
20
21
  from mindspore import _checkparam as Validator
@@ -38,12 +39,12 @@ class Gamma(Distribution):
38
39
 
39
40
  Args:
40
41
  concentration (int, float, list, numpy.ndarray, Tensor): The concentration,
41
- also know as alpha of the Gamma distribution. Default: None.
42
+ also know as alpha of the Gamma distribution. Default: ``None`` .
42
43
  rate (int, float, list, numpy.ndarray, Tensor): The rate, also know as
43
- beta of the Gamma distribution. Default: None.
44
- seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
45
- dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
46
- name (str): The name of the distribution. Default: 'Gamma'.
44
+ beta of the Gamma distribution. Default: ``None`` .
45
+ seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
46
+ dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
47
+ name (str): The name of the distribution. Default: ``'Gamma'`` .
47
48
 
48
49
  Note:
49
50
  `concentration` and `rate` must be greater than zero.
@@ -185,13 +186,12 @@ class Gamma(Distribution):
185
186
  self.squeeze = P.Squeeze(0)
186
187
  self.cast = P.Cast()
187
188
  self.dtypeop = P.DType()
188
- self.fill = P.Fill()
189
189
  self.shape = P.Shape()
190
190
  self.select = P.Select()
191
191
  self.greater = P.Greater()
192
- self.lgamma = nn.LGamma()
192
+ self.lgamma = P.Lgamma()
193
193
  self.digamma = nn.DiGamma()
194
- self.igamma = nn.IGamma()
194
+ self.igamma = P.Igamma()
195
195
 
196
196
  def extend_repr(self):
197
197
  """Display instance object as string."""
@@ -265,8 +265,8 @@ class Gamma(Distribution):
265
265
  """
266
266
  concentration, rate = self._check_param_type(concentration, rate)
267
267
  mode = (concentration - 1.) / rate
268
- nan = self.fill(self.dtypeop(concentration),
269
- self.shape(concentration), np.nan)
268
+ nan = F.fill(self.dtypeop(concentration), self.shape(concentration),
269
+ np.nan)
270
270
  comp = self.greater(concentration, 1.)
271
271
  return self.select(comp, mode, nan)
272
272