mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (884)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@
 """Defines parameter operators with functional form."""
 
 from __future__ import absolute_import
+import numpy as np
 
 from mindspore import context
 from mindspore.ops import operations as P
@@ -29,6 +30,12 @@ from mindspore.ops._primitive_cache import _get_cache_prim
 from mindspore.common.api import _function_forbid_reuse
 
 
+@constexpr
+def _set_prim_op_user_data(prim, key, value):
+    prim.add_prim_attr(key, value)
+    return prim
+
+
 @_function_forbid_reuse
 def random_gamma(shape, alpha, seed=None):
     r"""
@@ -41,7 +48,7 @@ def random_gamma(shape, alpha, seed=None):
         alpha (Tensor): The :math:`\alpha` distribution parameter.
             A Tensor. Must be one of the following types: half, float32, float64.
         seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Default: None, which will be treated as 0.
+            Default: ``None`` , which will be treated as 0.
 
     Returns:
         Tensor. The shape should be equal to the concat shape between the input `shape` and the broadcast
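As an aside (not part of the diff), a usage sketch of `ops.random_gamma` consistent with the Args/Returns text above; the concrete shapes are illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    shape = Tensor(np.array([7, 5]), ms.int32)        # leading dimensions of the output
    alpha = Tensor(np.array([0.5, 1.5]), ms.float32)  # Gamma distribution parameter(s)
    out = ops.random_gamma(shape, alpha, seed=5)
    print(out.shape)  # (7, 5, 2): `shape` concatenated with alpha's broadcast shape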
@@ -69,7 +76,8 @@ def random_gamma(shape, alpha, seed=None):
         (7, 5, 2)
     """
     seed1, seed2 = _get_seed(seed, "random_gamma")
-    random_gamma_op = _get_cache_prim(P.RandomGamma)(seed1, seed2)
+    random_gamma_op = P.RandomGamma(seed1, seed2)
+    random_gamma_op = _set_prim_op_user_data(random_gamma_op, "random_cache", False)
     output = random_gamma_op(shape, alpha)
     return output
 
@@ -93,7 +101,7 @@ def standard_laplace(shape, seed=None):
         shape (Union[tuple, Tensor]): The shape of random tensor to be generated. Only constant value is allowed
             when the input type is tuple. And the operator supports dynamic shape only when the input type is Tensor.
         seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
-            Default: None, which will be treated as 0.
+            Default: ``None`` , which will be treated as 0.
 
     Returns:
         Tensor. The shape that the input 'shape' denotes. The dtype is float32.
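Likewise, a short usage sketch for `ops.standard_laplace` (illustrative, not part of the diff):

    from mindspore import ops

    out = ops.standard_laplace((4, 4), seed=0)
    print(out.shape, out.dtype)  # (4, 4) Float32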
@@ -115,7 +123,8 @@ def standard_laplace(shape, seed=None):
         (4, 4)
     """
     seed1, seed2 = _get_seed(seed, "standard_laplace")
-    standard_laplace_op = _get_cache_prim(P.StandardLaplace)(seed=seed1, seed2=seed2)
+    standard_laplace_op = P.StandardLaplace(seed=seed1, seed2=seed2)
+    standard_laplace_op = _set_prim_op_user_data(standard_laplace_op, "random_cache", False)
     return standard_laplace_op(shape)
 
 
@@ -127,9 +136,9 @@ def random_categorical(logits, num_sample, seed=0, dtype=mstype.int64):
     Args:
         logits (Tensor): The input tensor. 2-D Tensor with shape :math:`(batch\_size, num\_classes)`.
         num_sample (int): Number of sample to be drawn. Only constant values is allowed.
-        seed (int): Random seed. Only constant values is allowed. Default: 0.
+        seed (int): Random seed. Only constant values is allowed. Default: ``0`` .
         dtype (mindspore.dtype): The type of output. Its value must be one of mindspore.int16,
-            mindspore.int32 and mindspore.int64. Default: mindspore.int64.
+            mindspore.int32 and mindspore.int64. Default: ``mstype.int64`` .
 
     Returns:
         Tensor, The output Tensor with shape :math:`(batch\_size, num\_samples)`.
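A comparable sketch for `ops.random_categorical`, drawing 8 class indices per row from a `(batch_size, num_classes)` logits tensor (illustrative values, not part of the diff):

    import numpy as np
    from mindspore import Tensor, ops
    from mindspore import dtype as mstype

    logits = Tensor(np.random.random((10, 5)).astype(np.float32))
    out = ops.random_categorical(logits, num_sample=8, seed=0, dtype=mstype.int64)
    print(out.shape)  # (10, 8)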
@@ -154,6 +163,7 @@ def random_categorical(logits, num_sample, seed=0, dtype=mstype.int64):
         (10, 8)
     """
     random_categorical_ = P.RandomCategorical(dtype)
+    random_categorical_ = _set_prim_op_user_data(random_categorical_, "random_cache", False)
     return random_categorical_(logits, num_sample, seed)
 
 
@@ -175,7 +185,7 @@ def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False)
             generator is seeded by a random seed. Otherwise, it is seeded by the given seed.
         offset (int): Offset used to avoid seed collision.
         numsamples (int): the number of samples to draw.
-        replacement (bool, optional): Whether to draw with replacement or not. Default: False.
+        replacement (bool, optional): Whether to draw with replacement or not. Default: ``False`` .
 
     Returns:
         Tensor with the same rows as `x`, each row has `numsamples` sampled indices.
@@ -194,6 +204,8 @@ def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False)
194
204
  ``CPU``
195
205
 
196
206
  Examples:
207
+ >>> from mindspore import Tensor, ops
208
+ >>> from mindspore import dtype as mstype
197
209
  >>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
198
210
  >>> output = ops.multinomial_with_replacement(x, 2, 5, 2, True)
199
211
  >>> print(output)
@@ -201,16 +213,17 @@ def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False)
201
213
  """
202
214
  if not isinstance(seed, Tensor):
203
215
  if not isinstance(seed, int):
204
- raise TypeError("For multinomial_with_replacement,",
205
- "the input[seed] must be int, but got {}.".format(type(seed)))
216
+ raise TypeError(f"For multinomial_with_replacement,",
217
+ f"the input[seed] must be int, but got {type(seed)}.")
206
218
  seed = Tensor(seed, dtype=mstype.int64)
207
219
  if not isinstance(offset, Tensor):
208
220
  if not isinstance(offset, int):
209
- raise TypeError("For multinomial_with_replacement,",
210
- "the input[offset] must be int, but got {}.".format(type(offset)))
221
+ raise TypeError(f"For multinomial_with_replacement,",
222
+ f"the input[offset] must be int, but got {type(offset)}.")
211
223
  offset = Tensor(offset, dtype=mstype.int64)
212
- multinomial_with_replacement_ = _get_cache_prim(P.MultinomialWithReplacement)(numsamples=numsamples,
213
- replacement=replacement)
224
+ multinomial_with_replacement_ = P.MultinomialWithReplacement(numsamples=numsamples,
225
+ replacement=replacement)
226
+ multinomial_with_replacement_ = _set_prim_op_user_data(multinomial_with_replacement_, "random_cache", False)
214
227
  return multinomial_with_replacement_(x, seed, offset)
215
228
 
216
229
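The type handling added above follows a normalize-or-reject pattern: a plain Python int for `seed` or `offset` is wrapped into an int64 Tensor, an existing Tensor is passed through, and anything else raises a TypeError. A standalone illustration using only public MindSpore types is sketched below; the helper name `_as_int64_tensor` is invented for the example and a working MindSpore install is assumed.

from mindspore import Tensor
from mindspore import dtype as mstype

def _as_int64_tensor(value, name):
    # Pass Tensors through, wrap Python ints, reject everything else.
    if isinstance(value, Tensor):
        return value
    if isinstance(value, int):
        return Tensor(value, dtype=mstype.int64)
    raise TypeError(f"For multinomial_with_replacement, the input[{name}] must be int, "
                    f"but got {type(value)}.")

seed = _as_int64_tensor(5, "seed")
offset = _as_int64_tensor(2, "offset")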
 
@@ -231,10 +244,10 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
231
244
  It defines the maximum possible generated value, with int32 or float32 data type.
232
245
  If dtype is int32, only one number is allowed.
233
246
  seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers,
234
- must be non-negative. Default: None, which will be treated as 0.
247
+ must be non-negative. Default: ``None`` , which will be treated as 0.
235
248
  dtype (mindspore.dtype): Type of the Uniform distribution. If it is int32, it generates numbers from discrete
236
249
  uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
237
- supports these two data types. Default: mindspore.float32.
250
+ supports these two data types. Default: ``mstype.float32`` .
238
251
 
239
252
  Returns:
240
253
  Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
@@ -281,11 +294,13 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
281
294
  seed1, seed2 = _get_seed(seed, "uniform")
282
295
  if const_utils.is_same_type(dtype, mstype.int32):
283
296
  random_uniform = P.UniformInt(seed1, seed2)
297
+ random_uniform = _set_prim_op_user_data(random_uniform, "random_cache", False)
284
298
  value = random_uniform(shape, minval, maxval)
285
299
  else:
286
300
  uniform_real = P.UniformReal(seed1, seed2)
287
- random_uniform = uniform_real(shape)
288
- value = random_uniform * (maxval - minval) + minval
301
+ uniform_real = _set_prim_op_user_data(uniform_real, "random_cache", False)
302
+ uniform_real = uniform_real(shape)
303
+ value = uniform_real * (maxval - minval) + minval
289
304
  return value
290
305
 
291
306
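The two branches above make the sampling strategy explicit: with an int32 dtype the values come straight from UniformInt over [minval, maxval), while with float32 the values are drawn from UniformReal on [0, 1) and rescaled as `value * (maxval - minval) + minval`. A small NumPy sketch of that rescaling identity (illustrative only, not the MindSpore kernel):

import numpy as np

rng = np.random.default_rng(0)
minval, maxval = -2.0, 3.0

u = rng.random(size=(4,), dtype=np.float32)    # uniform samples on [0, 1)
value = u * (maxval - minval) + minval         # rescaled to [minval, maxval)

assert (value >= minval).all() and (value < maxval).all()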
 
@@ -304,7 +319,7 @@ def standard_normal(shape, seed=None):
304
319
  shape (Union[tuple, Tensor]): The shape of random tensor to be generated. Only constant value is allowed
305
320
  when the input type is tuple. And the operator supports dynamic shape only when the input type is Tensor.
306
321
  seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
307
- Default: None, which will be treated as 0.
322
+ Default: ``None`` , which will be treated as 0.
308
323
 
309
324
  Returns:
310
325
  Tensor. The shape that the input 'shape' denotes. The dtype is float32.
@@ -325,7 +340,8 @@ def standard_normal(shape, seed=None):
325
340
  (4, 4)
326
341
  """
327
342
  seed1, seed2 = _get_seed(seed, "standard_normal")
328
- standard_normal_op = _get_cache_prim(P.StandardNormal)(seed=seed1, seed2=seed2)
343
+ standard_normal_op = P.StandardNormal(seed=seed1, seed2=seed2)
344
+ standard_normal_op = _set_prim_op_user_data(standard_normal_op, "random_cache", False)
329
345
  return standard_normal_op(shape)
330
346
 
331
347
 
@@ -344,23 +360,25 @@ def uniform_candidate_sampler(true_classes,
344
360
  If unique=True, candidates are drawn without replacement, else unique=False with replacement.
345
361
 
346
362
  Args:
347
- true_classes (Tensor): A Tensor. The target classes with a Tensor shape of :math:`(batch_size, num_true)` .
363
+ true_classes (Tensor): A Tensor. The target classes with a Tensor shape of :math:`(batch\_size, num\_true)` .
348
364
  num_true (int): The number of target classes in each training example.
349
365
  num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
350
366
  of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
351
367
  unique (bool): Whether all sampled classes in a batch are unique.
352
368
  range_max (int): The number of possible classes, must be positive.
353
369
  seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
354
- the seed will be replaced with a randomly generated value. Default: 0.
355
- remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
370
+ the seed will be replaced with a randomly generated value. Default: ``0`` .
371
+ remove_accidental_hits (bool): Whether accidental hit is removed.
372
+ An accidental hit occurs when one of the true classes matches one of the sampled classes.
373
+ Set ``True`` to remove sampled classes that accidentally match a true class. Default: ``False`` .
356
374
 
357
375
  Returns:
358
376
  - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
359
- Shape: :math:`(num_sampled, )` .
377
+ Shape: :math:`(num\_sampled, )` .
360
378
  - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
361
- of true_classes. Shape: :math:`(batch_size, num_true)` .
379
+ of true_classes. Shape: :math:`(batch\_size, num\_true)` .
362
380
  - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
363
- each of sampled_candidates. Shape: :math:`(num_sampled, )` .
381
+ each of sampled_candidates. Shape: :math:`(num\_sampled, )` .
364
382
 
365
383
  Raises:
366
384
  TypeError: If neither `num_true` nor `num_sampled` is an int.
@@ -372,6 +390,8 @@ def uniform_candidate_sampler(true_classes,
372
390
  ``Ascend`` ``GPU`` ``CPU``
373
391
 
374
392
  Examples:
393
+ >>> import numpy as np
394
+ >>> from mindspore import Tensor, ops
375
395
  >>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64))
376
396
  >>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1)
377
397
  >>> print(output1.shape)
@@ -381,12 +401,13 @@ def uniform_candidate_sampler(true_classes,
381
401
  >>> print(output3.shape)
382
402
  (3,)
383
403
  """
384
- sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true,
385
- num_sampled,
386
- unique,
387
- range_max,
388
- seed=seed,
389
- remove_accidental_hits=remove_accidental_hits)
404
+ sampler_op = P.UniformCandidateSampler(num_true,
405
+ num_sampled,
406
+ unique,
407
+ range_max,
408
+ seed=seed,
409
+ remove_accidental_hits=remove_accidental_hits)
410
+ sampler_op = _set_prim_op_user_data(sampler_op, "random_cache", False)
390
411
  sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(true_classes)
391
412
  return sampled_candidates, true_expected_count, sampled_expected_count
392
413
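To make the documented output shapes concrete: with the example inputs above (`batch_size=5`, `num_true=1`, `num_sampled=3`), `sampled_candidates` and `sampled_expected_count` come back as `(3,)`, while `true_expected_count` matches the `(5, 1)` shape of `true_classes`. The snippet below simply reruns the documented call and prints all three shapes (it assumes a working MindSpore install):

import numpy as np
from mindspore import Tensor, ops

data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64))
sampled, true_expected, sampled_expected = ops.uniform_candidate_sampler(
    data, 1, 3, False, 4, 1)
print(sampled.shape, true_expected.shape, sampled_expected.shape)
# Documented shapes: (3,) (5, 1) (3,)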
 
@@ -403,15 +424,15 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
403
424
 
404
425
  Args:
405
426
  shape (Tensor): The shape of random tensor to be sampled from each poisson distribution, 1-D `Tensor` whose
406
- dtype is mindspore.dtype.int32 or mindspore.dtype.int64.
427
+ dtype is mstype.int32 or mstype.int64.
407
428
  rate (Tensor): The :math:`μ` parameter the distribution is constructed with.
408
- It represents the mean of the distribution
409
- and also the variance of the distribution. It should be a `Tensor` whose dtype is mindspore.dtype.int64,
410
- mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16.
429
+ It represents the mean of the distribution
430
+ and also the variance of the distribution. It should be a `Tensor` whose dtype is mstype.int64,
431
+ mstype.int32, mstype.float64, mstype.float32 or mstype.float16.
411
432
  seed (int, optional): Seed is used as entropy source for the random number engines to generate pseudo-random
412
- numbers and must be non-negative. Default: None, which will be treated as 0.
413
- dtype (mindspore.dtype): The data type of output: mindspore.dtype.int64, mindspore.dtype.int32,
414
- mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. Default: mindspore.dtype.float32.
433
+ numbers and must be non-negative. Default: ``None`` , which will be treated as 0.
434
+ dtype (mindspore.dtype): The data type of output: ``mstype.int64``, ``mstype.int32``,
435
+ ``mstype.float64``, ``mstype.float32`` or ``mstype.float16``. Default: ``mstype.float32``.
415
436
 
416
437
  Returns:
417
438
  A Tensor whose shape is `mindspore.concat(['shape', mindspore.shape('rate')], axis=0)` and data type is equal to
@@ -419,14 +440,14 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
419
440
 
420
441
  Raises:
421
442
  TypeError: If `shape` is not a Tensor.
422
- TypeError: If datatype of `shape` is not mindspore.dtype.int64 nor mindspore.dtype.int32.
443
+ TypeError: If datatype of `shape` is not mstype.int64 nor mstype.int32.
423
444
  ValueError: If shape of `shape` is not 1-D.
424
445
  TypeError: If `rate` is not a Tensor nor a scalar.
425
- TypeError: If datatype of `rate` is not in [mindspore.dtype.int64, mindspore.dtype.int32,
426
- mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16].
446
+ TypeError: If datatype of `rate` is not in [mstype.int64, mstype.int32,
447
+ mstype.float64, mstype.float32 or mstype.float16].
427
448
  TypeError: If `seed` is not a non-negative int.
428
- TypeError: If `dtype` is not in [mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64,
429
- mindspore.dtype.float32 nor mindspore.dtype.float16].
449
+ TypeError: If `dtype` is not in [mstype.int64, mstype.int32, mstype.float64,
450
+ mstype.float32 nor mstype.float16].
430
451
  ValueError: If any element of input `shape` tensor is not positive.
431
452
 
432
453
  Supported Platforms:
@@ -450,7 +471,8 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
450
471
  (2, 2) Int64
451
472
  """
452
473
  seed1, seed2 = _get_seed(seed, "random_poisson")
453
- prim_random_poisson = P.random_ops.RandomPoisson(seed1, seed2, dtype)
474
+ prim_random_poisson = P.RandomPoisson(seed1, seed2, dtype)
475
+ prim_random_poisson = _set_prim_op_user_data(prim_random_poisson, "random_cache", False)
454
476
  value = prim_random_poisson(shape, rate)
455
477
  return value
456
478
 
@@ -463,7 +485,7 @@ def shuffle(x, seed=None):
463
485
  Args:
464
486
  x (Tensor): The Tensor need be shuffled.
465
487
  seed (int, optional): Random seed used for random number generation, must be non-negative. If `seed` is 0,
466
- which will be replaced with a randomly generated value. Default: None, which will be treated as 0.
488
+ it will be replaced with a randomly generated value. Default: ``None`` , which will be treated as 0.
467
489
 
468
490
  Returns:
469
491
  Tensor. The shape and type are the same as the input `x`.
@@ -475,13 +497,17 @@ def shuffle(x, seed=None):
475
497
  ``Ascend`` ``GPU`` ``CPU``
476
498
 
477
499
  Examples:
500
+ >>> import numpy as np
501
+ >>> from mindspore import Tensor, ops
502
+ >>> from mindspore import dtype as mstype
478
503
  >>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)
479
504
  >>> output = ops.shuffle(x, seed=1)
480
505
  >>> print(output)
481
- (3. 4. 2. 1.)
506
+ [3. 4. 2. 1.]
482
507
  """
483
508
  seed, seed2 = _get_seed(seed, "shuffle")
484
- random_shuffle_ = _get_cache_prim(RandomShuffle)(seed=seed, seed2=seed2)
509
+ random_shuffle_ = RandomShuffle(seed=seed, seed2=seed2)
510
+ random_shuffle_ = _set_prim_op_user_data(random_shuffle_, "random_cache", False)
485
511
  output = random_shuffle_(x)
486
512
  return output
487
513
 
@@ -496,13 +522,13 @@ def log_uniform_candidate_sampler(true_classes, num_true=1, num_sampled=5, uniqu
496
522
  Args:
497
523
  true_classes (Tensor): The target classes. With data type of int64 and
498
524
  shape :math:`(batch\_size, num\_true)` .
499
- num_true (int): The number of target classes per training example. Default: 1.
500
- num_sampled (int): The number of classes to randomly sample. Default: 5.
501
- unique (bool): Determines whether sample with rejection. If `unique` is True,
502
- all sampled classes in a batch are unique. Default: True.
503
- range_max (int): The number of possible classes. When `unique` is True,
504
- `range_max` must be greater than or equal to `num_sampled`. Default: 5.
505
- seed (int): Random seed, must be non-negative. Default: 0.
525
+ num_true (int): The number of target classes per training example. Default: ``1`` .
526
+ num_sampled (int): The number of classes to randomly sample. Default: ``5`` .
527
+ unique (bool): Determines whether sample with rejection. If `unique` is ``True`` ,
528
+ all sampled classes in a batch are unique. Default: ``True`` .
529
+ range_max (int): The number of possible classes. When `unique` is ``True`` ,
530
+ `range_max` must be greater than or equal to `num_sampled`. Default: ``5`` .
531
+ seed (int): Random seed, must be non-negative. Default: ``0`` .
506
532
 
507
533
  Returns:
508
534
  Tuple of 3 Tensors.
@@ -535,7 +561,8 @@ def log_uniform_candidate_sampler(true_classes, num_true=1, num_sampled=5, uniqu
535
561
 
536
562
  """
537
563
 
538
- sampler = _get_cache_prim(P.LogUniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed)
564
+ sampler = P.LogUniformCandidateSampler(num_true, num_sampled, unique, range_max, seed)
565
+ sampler = _set_prim_op_user_data(sampler, "random_cache", False)
539
566
  return sampler(true_classes)
540
567
 
541
568
 
@@ -552,9 +579,9 @@ def choice_with_mask(input_x, count=256, seed=None):
552
579
  Args:
553
580
  input_x (Tensor[bool]): The input tensor.
554
581
  The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
555
- count (int, optional): Number of items expected to get and the number must be greater than 0. Default: 256.
582
+ count (int, optional): Number of items expected to get and the number must be greater than 0. Default: ``256`` .
556
583
  seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
557
- Default: None, which will be treated as 0.
584
+ Default: ``None`` , which will be treated as 0.
558
585
 
559
586
  Returns:
560
587
  Two tensors, the first one is the index tensor and the other one is the mask tensor.
@@ -571,6 +598,8 @@ def choice_with_mask(input_x, count=256, seed=None):
571
598
  ``Ascend`` ``GPU`` ``CPU``
572
599
 
573
600
  Examples:
601
+ >>> import numpy as np
602
+ >>> from mindspore import Tensor, ops
574
603
  >>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool))
575
604
  >>> output_y, output_mask = ops.choice_with_mask(input_x)
576
605
  >>> result = output_y.shape
@@ -581,7 +610,8 @@ def choice_with_mask(input_x, count=256, seed=None):
581
610
  (256,)
582
611
  """
583
612
  seed1, seed2 = _get_seed(seed, "choice_with_mask")
584
- choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(count=count, seed=seed1, seed2=seed2)
613
+ choice_with_mask_ = RandomChoiceWithMask(count=count, seed=seed1, seed2=seed2)
614
+ choice_with_mask_ = _set_prim_op_user_data(choice_with_mask_, "random_cache", False)
585
615
  output = choice_with_mask_(input_x)
586
616
  return output
587
617
 
@@ -600,19 +630,23 @@ def randperm(n, seed=0, offset=0, dtype=mstype.int64):
600
630
  Returns the tensor with the determined shape inferred by n, the random numbers in it drawn from the data range
601
631
  that a given type can represent.
602
632
 
633
+ .. warning::
634
+ This is an experimental API that is subject to change or deletion.
635
+
603
636
  Args:
604
637
  n (Union[Tensor, int]): The input n Tensor with shape: () or (1,) and with data type of int64.
605
638
  The value of `n` must be greater than zero.
606
- seed (int, optional): Random seed. Default: 0. When seed is -1(only negative value), offset is 0,
639
+ seed (int, optional): Random seed. Default: ``0`` . When `seed` is -1 (the only negative value allowed) and `offset` is 0,
607
640
  it's determined by time.
608
641
  offset (int, optional): Offset to generate random numbers. Priority is higher than random seed.
609
- Default: 0. It must be non-negative.
642
+ Default: ``0`` . It must be non-negative.
610
643
  dtype (mindspore.dtype, optional): The type of output.
611
644
  Its value must be one of the following types: int32, int16, int8,
612
- uint8, int64, float64, float32, float16. Default: int64.
645
+ uint8, int64, float64, float32, float16. Default: ``mstype.int64`` .
613
646
 
614
647
  Returns:
615
- Tensor. Its shape is specified by the required args `n`. Its type is spcified by `dtype`. Otherwise is default.
648
+ Tensor. Its shape is specified by the required arg `n`. Its type is specified by `dtype`.
649
+ If `dtype` is not specified, the default ``mstype.int64`` is used.
616
650
 
617
651
  Raises:
618
652
  TypeError: If `dtype` is not allowed.
@@ -624,6 +658,8 @@ def randperm(n, seed=0, offset=0, dtype=mstype.int64):
624
658
  ``CPU``
625
659
 
626
660
  Examples:
661
+ >>> from mindspore import ops
662
+ >>> from mindspore import dtype as mstype
627
663
  >>> n = 4
628
664
  >>> seed = 0
629
665
  >>> offset = 0
@@ -633,7 +669,8 @@ def randperm(n, seed=0, offset=0, dtype=mstype.int64):
633
669
  """
634
670
  if not isinstance(n, Tensor):
635
671
  n = Tensor(n)
636
- randperm_ = _get_cache_prim(RandpermV2)(dtype=dtype)
672
+ randperm_ = RandpermV2(dtype=dtype)
673
+ randperm_ = _set_prim_op_user_data(randperm_, "random_cache", False)
637
674
  return randperm_(n, seed, offset)
638
675
 
639
676
 
@@ -645,17 +682,15 @@ def normal(shape, mean, stddev, seed=None):
645
682
  Args:
646
683
  shape (tuple): The shape of random tensor to be generated.
647
684
  The format is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
648
- mean (Union[Tensor, int, float]): The mean μ distribution parameter, which specifies the location of the peak,
649
- with data type in [int8, int16, int32, int64, float16, float32].
650
- stddev (Union[Tensor, int, float]): The deviation σ distribution parameter. It should be greater than 0,
651
- with data type in [int8, int16, int32, int64, float16, float32].
685
+ mean (Union[Tensor, int, float]): The mean μ distribution parameter, which specifies the location of the peak.
686
+ stddev (Union[Tensor, int, float]): The deviation σ distribution parameter. It should be greater than 0.
652
687
  seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
653
- The value must be non-negative. Default: None, which will be treated as 0.
688
+ The value must be non-negative. Default: ``None`` , which will be treated as 0.
654
689
 
655
690
  Returns:
656
691
  Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
657
692
  of `mean` and `stddev`.
658
- The dtype is float32.
693
+ The dtype is float32 or float64.
659
694
 
660
695
  Supported Platforms:
661
696
  ``Ascend`` ``GPU`` ``CPU``
@@ -692,12 +727,9 @@ def normal(shape, mean, stddev, seed=None):
692
727
  mean = Tensor(mean)
693
728
  if not isinstance(stddev, Tensor):
694
729
  stddev = Tensor(stddev)
695
- mean_dtype = F.dtype(mean)
696
- stddev_dtype = F.dtype(stddev)
697
- const_utils.check_type_valid(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
698
- const_utils.check_type_valid(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
699
730
  seed1, seed2 = _get_seed(seed, "normal")
700
731
  stdnormal = P.StandardNormal(seed1, seed2)
732
+ stdnormal = _set_prim_op_user_data(stdnormal, "random_cache", False)
701
733
  _check_shape(shape)
702
734
  random_normal = stdnormal(shape)
703
735
  value = random_normal * stddev + mean
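The body above shows how `normal` is assembled from lower-level pieces: draw standard normal noise with the requested shape, then apply the affine transform `noise * stddev + mean`. The same identity can be sanity-checked in plain NumPy (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
mean, stddev = 5.0, 2.0

noise = rng.standard_normal((100000,)).astype(np.float32)
value = noise * stddev + mean

# Sample statistics should land close to the requested mean and stddev.
print(round(float(value.mean()), 2), round(float(value.std()), 2))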
@@ -721,7 +753,7 @@ def laplace(shape, mean, lambda_param, seed=None):
721
753
  lambda_param (Tensor): The parameter used for controlling the variance of this random distribution. The
722
754
  variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type.
723
755
  seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
724
- Default: None, which will be treated as 0.
756
+ Default: ``None`` , which will be treated as 0.
725
757
 
726
758
  Returns:
727
759
  Tensor. The shape should be the broadcasted shape of input `shape` and shapes of `mean` and `lambda_param`.
@@ -747,6 +779,7 @@ def laplace(shape, mean, lambda_param, seed=None):
747
779
  const_utils.check_tensors_dtype_same(lambda_param_dtype, mstype.float32, "laplace")
748
780
  seed1, seed2 = _get_seed(seed, "laplace")
749
781
  stdlaplace = P.StandardLaplace(seed1, seed2)
782
+ stdlaplace = _set_prim_op_user_data(stdlaplace, "random_cache", False)
750
783
  _check_shape(shape)
751
784
  rnd = stdlaplace(shape)
752
785
  value = rnd * lambda_param + mean
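`laplace` is composed the same way: sample StandardLaplace noise and apply `rnd * lambda_param + mean`. Because a standard Laplace draw has variance 2, scaling by `lambda_param` yields a variance of `2 * lambda_param**2`, which is exactly what the docstring above states. A NumPy check of that relationship (illustrative, not the MindSpore kernel):

import numpy as np

rng = np.random.default_rng(0)
mean, lam = 1.0, 3.0

std_laplace = rng.laplace(loc=0.0, scale=1.0, size=200000)
value = std_laplace * lam + mean

# Variance of Laplace(mean, lam) should be close to 2 * lam**2.
print(round(float(value.var()), 1), 2 * lam ** 2)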
@@ -763,7 +796,7 @@ def gamma(shape, alpha, beta, seed=None):
763
796
  alpha (Tensor): The :math:`\alpha` distribution parameter. It should be greater than 0 with float32 data type.
764
797
  beta (Tensor): The :math:`\beta` distribution parameter. It should be greater than 0 with float32 data type.
765
798
  seed (int): Seed is used as entropy source for the random number engines to generate
766
- pseudo-random numbers, must be non-negative. Default: None, which will be treated as 0.
799
+ pseudo-random numbers, must be non-negative. Default: ``None`` , which will be treated as ``0`` .
767
800
 
768
801
  Returns:
769
802
  Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
@@ -804,30 +837,29 @@ def gamma(shape, alpha, beta, seed=None):
804
837
  >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
805
838
  >>> beta = Tensor(np.array([1.0, 2]), mindspore.float32)
806
839
  >>> output = ops.gamma(shape, alpha, beta, seed=5)
807
- >>> result = output.shape
808
840
  >>> print(output)
809
- [[[ 2.2132034 5.8855834]]
810
- [ 3.3981476 7.5805717]
811
- [[ 3.3981476 7.5805717]]
812
- [ 3.7190282 19.941492]
813
- [[ 2.9512358 2.5969937]]
814
- [ 3.786061 5.160872 ]]]
841
+ [[[ 2.2132034 5.8855834]
842
+ [ 3.8825176 8.6066265]]
843
+ [[ 3.3981476 7.5805717]
844
+ [ 3.7190282 19.941492 ]]
845
+ [[ 2.9512358 2.5969937]
846
+ [ 3.786061 5.160872 ]]]
815
847
  >>> # case 4: beta_shape is (2, 1), the output is different.
816
848
  >>> shape = (3, 1, 2)
817
849
  >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
818
850
  >>> beta = Tensor(np.array([[1.0], [2.0]]), mindspore.float32)
819
851
  >>> output = ops.gamma(shape, alpha, beta, seed=5)
820
- >>> result = output.shape
821
852
  >>> print(output)
822
- [[[ 5.6085486 7.8280783]]
823
- [ 15.97684 16.116285]
824
- [[ 1.8347423 1.713663]]
825
- [ 3.2434065 15.667398]
826
- [[ 4.2922077 7.3365674]]
853
+ [[[ 5.6085486 7.8280783]
854
+ [ 15.97684 16.116285]]
855
+ [[ 1.8347423 1.713663]
856
+ [ 3.2434065 15.667398]]
857
+ [[ 4.2922077 7.3365674]
827
858
  [ 5.3876944 13.159832 ]]]
828
859
  """
829
860
  seed1, seed2 = _get_seed(seed, "gamma")
830
861
  gamma_v = P.Gamma(seed1, seed2)
862
+ gamma_v = _set_prim_op_user_data(gamma_v, "random_cache", False)
831
863
  value = gamma_v(shape, alpha, beta)
832
864
  return value
833
865
 
@@ -845,13 +877,13 @@ def _generate_shapes(shape):
845
877
  elif isinstance(shape[0], tuple):
846
878
  size = shape[0]
847
879
  else:
848
- raise TypeError("If the length of the argument 'shape' is 1, the type of the argument 'shape' must be "
849
- "one of ['int', 'list', 'tuple'], but got ", shape[0])
880
+ raise TypeError(f"If the length of the argument 'shape' is 1, the type of the argument 'shape' must be "
881
+ f"one of ['int', 'list', 'tuple'], but got {shape[0]}.")
850
882
  else:
851
883
  for value in shape:
852
884
  if not isinstance(value, int):
853
- raise TypeError("If the length of the argument 'shape' is > 1, the type of the argument 'shape' must "
854
- "all be int, but got ", value)
885
+ raise TypeError(f"If the length of the argument 'shape' is > 1, the type of the argument 'shape' must "
886
+ f"all be int, but got {value}.")
855
887
  size = shape
856
888
  return size
857
889
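`_generate_shapes` is the helper that lets the `rand`/`randn` family accept either separate ints or a single list/tuple, as the error messages above imply. Assuming a working MindSpore install, the equivalent call forms would all normalize to the same shape:

from mindspore import ops

# All three spellings should be normalized to the same (2, 3) shape by _generate_shapes.
a = ops.rand(2, 3)
b = ops.rand((2, 3))
c = ops.rand([2, 3])
print(a.shape, b.shape, c.shape)   # expected: (2, 3) (2, 3) (2, 3)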
 
@@ -867,8 +899,8 @@ def rand(*size, dtype=None, seed=None):
867
899
 
868
900
  Keyword Args:
869
901
  dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
870
- `mindspore.float32` will be applied. Default: None.
871
- seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
902
+ `mindspore.float32` will be applied. Default: ``None`` .
903
+ seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and ``0`` will be used.
872
904
 
873
905
  Returns:
874
906
  Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
@@ -895,6 +927,7 @@ def rand(*size, dtype=None, seed=None):
895
927
  cast_ = P.Cast()
896
928
  seed1, seed2 = _get_seed(seed, 'rand')
897
929
  rand_op = P.UniformReal(seed1, seed2)
930
+ rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
898
931
  output = rand_op(shape)
899
932
  return cast_(output, dtype)
900
933
 
@@ -907,11 +940,11 @@ def rand_like(input, seed=None, *, dtype=None):
907
940
 
908
941
  Args:
909
942
  input (Tensor): Input Tensor to specify the output shape and its default dtype.
910
- seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
943
+ seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and ``0`` will be used.
911
944
 
912
945
  Keyword Args:
913
946
  dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
914
- the same dtype of `input` will be applied. Default: None.
947
+ the same dtype of `input` will be applied. Default: ``None`` .
915
948
 
916
949
  Returns:
917
950
  Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
@@ -932,15 +965,17 @@ def rand_like(input, seed=None, *, dtype=None):
932
965
  [[4.1702199e-01 9.9718481e-01 7.2032452e-01]
933
966
  [9.3255734e-01 1.1438108e-04 1.2812445e-01]]
934
967
  """
935
-
968
+ if not isinstance(input, Tensor):
969
+ raise TypeError(f"For 'rand_like', the 'input' must be a Tensor, but got {type(input)}")
936
970
  if dtype is None:
937
971
  dtype = input.dtype
938
- elif dtype not in mstype.float_type:
972
+ if dtype not in mstype.float_type:
939
973
  raise ValueError(f"For 'rand_like', the 'dtype' must be a float type, but got {dtype}.")
940
974
  shape = input.shape
941
975
  cast_ = P.Cast()
942
976
  seed1, seed2 = _get_seed(seed, 'rand_like')
943
977
  rand_op = P.UniformReal(seed1, seed2)
978
+ rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
944
979
  output = rand_op(shape)
945
980
  return cast_(output, dtype)
946
981
 
@@ -956,8 +991,8 @@ def randn(*size, dtype=None, seed=None):
956
991
 
957
992
  Keyword Args:
958
993
  dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
959
- `mindspore.float32` will be used. Default: None.
960
- seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
994
+ `mindspore.float32` will be used. Default: ``None`` .
995
+ seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and 0 will be used.
961
996
 
962
997
  Returns:
963
998
  Tensor, with the designated shape and dtype, filled with a sample (or samples) from the
@@ -985,6 +1020,7 @@ def randn(*size, dtype=None, seed=None):
985
1020
  cast_ = P.Cast()
986
1021
  seed1, seed2 = _get_seed(seed, 'randn')
987
1022
  rand_op = P.StandardNormal(seed1, seed2)
1023
+ rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
988
1024
  output = rand_op(shape)
989
1025
  return cast_(output, dtype)
990
1026
 
@@ -997,11 +1033,11 @@ def randn_like(input, seed=None, *, dtype=None):
997
1033
 
998
1034
  Args:
999
1035
  input (Tensor): Input Tensor to specify the output shape and its default dtype.
1000
- seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
1036
+ seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and 0 will be used.
1001
1037
 
1002
1038
  Keyword Args:
1003
1039
  dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
1004
- `mindspore.float32` will be used. Default: None.
1040
+ `mindspore.float32` will be used. Default: ``None`` .
1005
1041
 
1006
1042
  Returns:
1007
1043
  Tensor, with the designated shape and dtype, filled with a sample (or samples) from the
@@ -1022,14 +1058,17 @@ def randn_like(input, seed=None, *, dtype=None):
1022
1058
  [[ 0.30639967 -0.42438635 -0.20454668]
1023
1059
  [-0.4287376 1.3054721 0.64747655]]
1024
1060
  """
1061
+ if not isinstance(input, Tensor):
1062
+ raise TypeError(f"For 'randn_like', the 'input' must be a Tensor, but got {type(input)}")
1025
1063
  if dtype is None:
1026
- dtype = input.dtype
1027
- elif dtype not in mstype.float_type:
1064
+ dtype = mstype.float32
1065
+ if dtype not in mstype.float_type:
1028
1066
  raise ValueError(f"For 'randn_like', the 'dtype' must be a float type, but got {dtype}.")
1029
1067
  shape = input.shape
1030
1068
  cast_ = P.Cast()
1031
1069
  seed1, seed2 = _get_seed(seed, 'randn_like')
1032
1070
  rand_op = P.StandardNormal(seed1, seed2)
1071
+ rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
1033
1072
  output = rand_op(shape)
1034
1073
  return cast_(output, dtype)
1035
1074
 
@@ -1043,11 +1082,11 @@ def randint(low, high, size, seed=None, *, dtype=None):
1043
1082
  low (int): Start value of interval.
1044
1083
  high (int): End value of interval.
1045
1084
  size (tuple): Shape of the new tensor.
1046
- seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
1085
+ seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and ``0`` will be used.
1047
1086
 
1048
1087
  Keyword Args:
1049
- dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If None,
1050
- `mindspore.int64` will be used. Default: None.
1088
+ dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If ``None`` ,
1089
+ `mindspore.int64` will be used. Default: ``None`` .
1051
1090
 
1052
1091
  Returns:
1053
1092
  Tensor, with the designated shape and dtype, filled with random integers from low (inclusive)
@@ -1075,11 +1114,14 @@ def randint(low, high, size, seed=None, *, dtype=None):
1075
1114
  raise ValueError(f"For 'randint', the 'dtype' must be an int type, but got {dtype}.")
1076
1115
  if not isinstance(size, tuple):
1077
1116
  raise ValueError(f"For 'randint', the input 'size' must be a tuple, but got {size}.")
1078
- if not isinstance(low, int) or not isinstance(high, int):
1079
- raise TypeError(f"For 'randint', 'low' and 'high' must be an int, but got {type(low)} and {type(high)}.")
1117
+ if not isinstance(low, int) or isinstance(low, bool):
1118
+ raise TypeError(f"For 'randint_like', 'low' must be an int, but got {type(low)}.")
1119
+ if not isinstance(high, int) or isinstance(high, bool):
1120
+ raise TypeError(f"For 'randint_like', 'high' must be an int, but got {type(high)}.")
1080
1121
  seed1, seed2 = _get_seed(seed, 'randint')
1081
1122
  cast_ = P.Cast()
1082
1123
  rand_op = P.UniformInt(seed1, seed2)
1124
+ rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
1083
1125
  low_ = Tensor(low, mstype.int32)
1084
1126
  high_ = Tensor(high, mstype.int32)
1085
1127
  output = rand_op(size, low_, high_)
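The new `isinstance(low, bool)` / `isinstance(high, bool)` guards are needed because `bool` is a subclass of `int` in Python, so the previous check would silently accept `True`/`False` as bounds. A short illustration of the distinction the added checks rely on (the `_check_bound` helper is invented for the example):

print(isinstance(True, int))    # True: bool is a subclass of int
print(isinstance(True, bool))   # True

def _check_bound(value, name):
    # Same guard shape as the diff: accept real ints, reject bools explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError(f"'{name}' must be an int, but got {type(value)}.")

_check_bound(10, "high")        # passes silently
# _check_bound(True, "high")    # would raise TypeError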
@@ -1096,11 +1138,11 @@ def randint_like(input, low, high, seed=None, *, dtype=None):
1096
1138
  input (Tensor): Input Tensor to specify the output shape and its default dtype.
1097
1139
  low(int): Start value of interval.
1098
1140
  high(int): End value of interval.
1099
- seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
1141
+ seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and 0 will be used.
1100
1142
 
1101
1143
  Keyword Args:
1102
- dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If None,
1103
- `mindspore.int64` will be used. Default is `mindspore.int64`.
1144
+ dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If ``None`` ,
1145
+ the same dtype of `input` will be applied. Default: ``None`` .
1104
1146
 
1105
1147
  Returns:
1106
1148
  Tensor, with the designated shape and dtype, filled with random integers from low (inclusive)
@@ -1121,15 +1163,20 @@ def randint_like(input, low, high, seed=None, *, dtype=None):
1121
1163
  [[4 9 7]
1122
1164
  [9 1 2]]
1123
1165
  """
1166
+ if not isinstance(input, Tensor):
1167
+ raise TypeError(f"For 'randint_like', the 'input' must be a Tensor, but got {type(input)}")
1124
1168
  if dtype is None:
1125
1169
  dtype = input.dtype
1126
- elif dtype not in mstype.int_type:
1170
+ if dtype not in mstype.int_type:
1127
1171
  raise ValueError(f"For 'randint_like', the 'dtype' must be an int type, but got {dtype}.")
1128
- if not isinstance(low, int) or not isinstance(high, int):
1129
- raise TypeError(f"For 'randint_like', 'low' and 'high' must be an int, but got {type(low)} and {type(high)}.")
1172
+ if not isinstance(low, int) or isinstance(low, bool):
1173
+ raise TypeError(f"For 'randint_like', 'low' must be an int, but got {type(low)}.")
1174
+ if not isinstance(high, int) or isinstance(high, bool):
1175
+ raise TypeError(f"For 'randint_like', 'high' must be an int, but got {type(high)}.")
1130
1176
  size = input.shape
1131
1177
  seed1, seed2 = _get_seed(seed, 'randint_like')
1132
1178
  rand_op = P.UniformInt(seed1, seed2)
1179
+ rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
1133
1180
  cast_ = P.Cast()
1134
1181
  low_ = Tensor(low, mstype.int32)
1135
1182
  high_ = Tensor(high, mstype.int32)
@@ -1152,7 +1199,7 @@ def poisson(shape, mean, seed=None):
1152
1199
  The format is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
1153
1200
  mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
1154
1201
  seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
1155
- and must be non-negative. Default: None, which will be treated as 0.
1202
+ and must be non-negative. Default: ``None`` , which will be treated as 0.
1156
1203
 
1157
1204
  Returns:
1158
1205
  Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`.
@@ -1186,6 +1233,7 @@ def poisson(shape, mean, seed=None):
1186
1233
  """
1187
1234
  seed1, seed2 = _get_seed(seed, "poisson")
1188
1235
  random_poisson_op = P.Poisson(seed1, seed2)
1236
+ random_poisson_op = _set_prim_op_user_data(random_poisson_op, "random_cache", False)
1189
1237
  value = random_poisson_op(shape, mean)
1190
1238
  return value
1191
1239
 
@@ -1196,17 +1244,42 @@ def multinomial(input, num_samples, replacement=True, seed=None):
1196
1244
  Returns a tensor sampled from the multinomial probability distribution located in the corresponding
1197
1245
  row of the input tensor.
1198
1246
 
1247
+ The polynomial distribution is a probability distribution that generalizes the binomial distribution formula to
1248
+ multiple states. In the polynomial distribution, each event has a fixed probability, and the sum of these
1249
+ probabilities is 1. The purpose of the `mindspore.ops.multinomial` interface is to perform `num_samples` sampling
1250
+ on the input `input`, and the output tensor is the index of the input tensor for each sampling.
1251
+ The values in `input` represent the probability of selecting the corresponding index for each sampling.
1252
+
1253
+ Here is an extreme example for better understanding. Suppose we have an input probability tensor with
1254
+ values `Tensor([90 / 100, 10 / 100, 0], mindspore.float32)`, which means we can sample three indices,
1255
+ namely index 0, index 1, and index 2, with probabilities of 90%, 10%, and 0%, respectively. We perform n samplings,
1256
+ and the resulting sequence is the calculation result of the multinomial distribution, with a length equal to the
1257
+ number of samplings.
1258
+
1259
+ In case 1 of the sample code, we perform two non-replacement samplings (`replacement` is `False`).
1260
+ The calculation result is most likely `[0, 1]`, and less likely `[1, 0]`. Since the probability of selecting
1261
+ index 0 is 90% for each sampling, the first result is most likely to be index 0. Since the probability of selecting
1262
+ index 2 is 0, index 2 cannot appear in the sampling result. Therefore, the second result must be index 1,
1263
+ and the resulting sequence is `[0, 1]`.
1264
+
1265
+ In case 2 of the sample code, we perform 10 replacement samplings (`replacement` is `True`).
1266
+ As expected, about 90% of the sampling results are index 0.
1267
+
1268
+ In case 3 of the sample code, we extend the input to 2 dimensions, and the sampling results
1269
+ in each dimension also match our sampling expectations.
1270
+
1199
1271
  Note:
1200
1272
  The rows of input do not need to sum to one (in which case we use the values as weights),
1201
- but must be non-negative, finite and have a non-zero sum.
1273
+ but must be non-negative, finite and have a non-zero sum. When using values as weights, it can be understood as
1274
+ normalizing the input along the last dimension.
1202
1275
 
1203
1276
  Args:
1204
1277
  input (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
1205
1278
  float32 data type.
1206
1279
  num_samples (int): Number of samples to draw.
1207
- replacement (bool, optional): Whether to draw with replacement or not, default: True.
1280
+ replacement (bool, optional): Whether to draw with replacement or not. Default: ``True`` .
1208
1281
  seed (int, optional): Seed is used as entropy source for the random number engines to generate
1209
- pseudo-random numbers, must be non-negative. Default: None.
1282
+ pseudo-random numbers, must be non-negative. Default: ``None`` .
1210
1283
 
1211
1284
  Returns:
1212
1285
  Tensor, has the same rows with input. The number of sampled indices of each row is `num_samples`.
@@ -1225,65 +1298,86 @@ def multinomial(input, num_samples, replacement=True, seed=None):
1225
1298
  >>> from mindspore import Tensor, ops
1226
1299
  >>> from mindspore import dtype as mstype
1227
1300
  >>> # case 1: The output is random, and the length of the output is the same as num_sample.
1228
- >>> input = Tensor([0, 9, 4, 0], mindspore.float32)
1229
- >>> output = ops.multinomial(input, 2)
1230
- >>> # print(output)
1231
- >>> # [1 2] or [2 1]
1232
- >>> # the case where the result is [2 1] in multiple times.
1233
- >>> # This is because the value corresponding to the index 1 is larger than the value of the index 2.
1234
- >>> print(len(output))
1301
+ >>> # replacement is False.
1302
+ >>> input1 = Tensor([90 / 100, 10 / 100, 0], mindspore.float32)
1303
+ >>> input2 = Tensor([90, 10, 0], mindspore.float32)
1304
+ >>> # input1 and input2 have the same meaning.
1305
+ >>> output1 = ops.multinomial(input1, 2, replacement=False)
1306
+ >>> output2 = ops.multinomial(input2, 2, replacement=False)
1307
+ >>> # print(output1)
1308
+ >>> # [0 1]
1309
+ >>> # print(output2)
1310
+ >>> # [0 1]
1311
+ >>> print(len(output1))
1312
+ 2
1313
+ >>> print(len(output2))
1235
1314
  2
1236
1315
  >>> # case 2: The output is random, and the length of the output is the same as num_sample.
1237
- >>> # replacement is False(Default).
1238
- >>> # If the extracted value is 0, the index value of 1 will be returned.
1239
- >>> input = Tensor([0, 9, 4, 0], mstype.float32)
1240
- >>> output = ops.multinomial(input, 4)
1241
- >>> print(output)
1242
- [1 1 2 1]
1243
- >>> # case 3: The output is random, num_sample == x_length = 4, and replacement is True,
1244
- >>> # Can extract the same elements。
1245
- >>> input = Tensor([0, 9, 4, 0], mstype.float32)
1246
- >>> output = ops.multinomial(input, 4, True)
1247
- >>> print(output)
1248
- [1 1 2 2]
1316
+ >>> # replacement is True.
1317
+ >>> output3 = ops.multinomial(input1, 10)
1318
+ >>> # print(output3)
1319
+ >>> # [0 0 1 0 0 0 0 0 0 0]
1320
+ >>> print(len(output3))
1321
+ 10
1322
+ >>> # case 3: The output is random, and the length of the output is the same as num_sample.
1323
+ >>> # replacement is True.
1324
+ >>> # rank is 2
1325
+ >>> input4 = Tensor([[90, 10, 0], [10, 90, 0]], mstype.float32)
1326
+ >>> output4 = ops.multinomial(input4, 10)
1327
+ >>> # print(output4)
1328
+ >>> # [[0 0 0 0 0 0 0 0 1 0]
1329
+ >>> # [1 1 1 1 1 0 1 1 1 1]]
1249
1330
  """
1250
- shape = P.Shape()
1251
- reshape = P.Reshape()
1252
- const_utils.check_valid_dim(len(shape(input)), "multinomial")
1331
+ shape = _get_cache_prim(P.Shape)()
1332
+ reshape = _get_cache_prim(P.Reshape)()
1333
+
1334
+ def _check_valid_dim(dim, name):
1335
+ if dim not in (1, 2):
1336
+ raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")
1337
+
1338
+ _check_valid_dim(len(shape(input)), "multinomial")
1253
1339
  seed1, seed2 = _get_seed(seed, "multinomial")
1254
1340
  if not replacement:
1255
1341
  if shape(input)[-1] < num_samples:
1256
- const_utils.raise_value_error("For 'multinomial', the 'num_samples' must be less than "
1257
- "the last dimension of input without 'replacement', "
1258
- "but got 'num_samples': {} and "
1259
- "'replacement': {}".format(num_samples, replacement))
1342
+ const_utils.raise_value_error(f"For 'multinomial', the 'num_samples' must be less than "
1343
+ f"the last dimension of input without 'replacement', "
1344
+ f"but got 'num_samples': {num_samples} and "
1345
+ f"'replacement': {replacement}")
1260
1346
  n_dist = 1
1261
1347
  if len(shape(input)) > 1:
1262
1348
  n_dist = shape(input)[-2]
1263
- random_uniform = P.UniformReal(seed1, seed2)((n_dist * shape(input)[-1],))
1349
+ random_uniform_real = P.UniformReal(seed1, seed2)
1350
+ random_cache_op = _set_prim_op_user_data(random_uniform_real, "random_cache", False)
1351
+ random_uniform = random_cache_op((n_dist * shape(input)[-1],))
1264
1352
  if n_dist != 1:
1265
1353
  random_uniform = reshape(random_uniform, (n_dist, shape(input)[-1]))
1266
- vals = P.RealDiv()(P.Log()(random_uniform), input + 1e-6)
1267
- _, indices = P.TopK()(vals, num_samples)
1354
+ real_div = _get_cache_prim(P.RealDiv)()
1355
+ log = _get_cache_prim(P.Log)()
1356
+ top_k = _get_cache_prim(P.TopK)()
1357
+
1358
+ vals = real_div(log(random_uniform), input + 1e-6)
1359
+ _, indices = top_k(vals, num_samples)
1268
1360
  return indices
1269
- return P.Multinomial(seed1, seed2)(input, num_samples)
1361
+ random_nomial = P.Multinomial(seed1, seed2)
1362
+ random_nomial = _set_prim_op_user_data(random_nomial, "random_cache", False)
1363
+ return random_nomial(input, num_samples)
1270
1364
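The without-replacement branch above does not loop over draws; it uses a weighted-key trick: sample `u ~ Uniform(0, 1)` per class, compute `log(u) / (weight + 1e-6)`, and take the TopK indices. In log space this is the Efraimidis-Spirakis key `u**(1/w)`, so classes with larger weights tend to get larger keys and are selected first, while zero-weight classes effectively never win. A NumPy sketch of the same idea (illustrative; the 1e-6 mirrors the epsilon in the code above):

import numpy as np

rng = np.random.default_rng(0)
weights = np.array([90.0, 10.0, 0.0], dtype=np.float32)
num_samples = 2

u = rng.random(weights.shape).astype(np.float32)   # u ~ Uniform(0, 1)
keys = np.log(u) / (weights + 1e-6)                # weighted keys in log space
indices = np.argsort(keys)[::-1][:num_samples]     # top-k keys = draw without replacement

print(indices)   # almost always [0 1]; the zero-weight class effectively never appears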
 
1271
1365
 
1272
1366
  def _check_shape(input_shape):
1273
1367
  """Check 'shape' value."""
1274
1368
  if not isinstance(input_shape, tuple):
1275
- const_utils.raise_type_error("Type of 'shape' must be tuple, but got: {}".format(type(input_shape)))
1369
+ const_utils.raise_type_error(f"Type of 'shape' must be tuple, but got: {type(input_shape)}")
1276
1370
  for item in input_shape:
1277
1371
  if not isinstance(item, int):
1278
- const_utils.raise_type_error("Elements of 'shape' must be int, but got: {}".format(type(item)))
1372
+ const_utils.raise_type_error(f"Elements of 'shape' must be int, but got: {type(item)}")
1279
1373
  if item < 1:
1280
- const_utils.raise_value_error("Elements of 'shape' must be positive int, but got: {}".format(item))
1374
+ const_utils.raise_value_error(f"Elements of 'shape' must be positive int, but got: {item}")
1281
1375
  return True
1282
1376
 
1283
1377
 
1284
1378
  def _check_param(op_name, param_name, param_value):
1285
1379
  """Check type of param_value is Tensor, int, or float."""
1286
- if not isinstance(param_value, (Tensor, int, float)):
1380
+ if not isinstance(param_value, (Tensor, int, float, np.ndarray)):
1287
1381
  const_utils.raise_type_error("For '{}', the type of '{}' must be Tensor, int, or float, "
1288
1382
  "but got: {}".format(op_name, param_name, type(param_value)))
1289
1383
  return True