mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package registry's advisory page for more details.

Files changed (550) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@
15
15
  """Defines parameter operators with functional form."""
16
16
 
17
17
  from __future__ import absolute_import
18
+ import numpy as np
18
19
 
19
20
  from mindspore import context
20
21
  from mindspore.ops import operations as P
@@ -212,13 +213,13 @@ def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False)
212
213
  """
213
214
  if not isinstance(seed, Tensor):
214
215
  if not isinstance(seed, int):
215
- raise TypeError("For multinomial_with_replacement,",
216
- "the input[seed] must be int, but got {}.".format(type(seed)))
216
+ raise TypeError(f"For multinomial_with_replacement,",
217
+ f"the input[seed] must be int, but got {type(seed)}.")
217
218
  seed = Tensor(seed, dtype=mstype.int64)
218
219
  if not isinstance(offset, Tensor):
219
220
  if not isinstance(offset, int):
220
- raise TypeError("For multinomial_with_replacement,",
221
- "the input[offset] must be int, but got {}.".format(type(offset)))
221
+ raise TypeError(f"For multinomial_with_replacement,",
222
+ f"the input[offset] must be int, but got {type(offset)}.")
222
223
  offset = Tensor(offset, dtype=mstype.int64)
223
224
  multinomial_with_replacement_ = P.MultinomialWithReplacement(numsamples=numsamples,
224
225
  replacement=replacement)
@@ -359,7 +360,7 @@ def uniform_candidate_sampler(true_classes,
359
360
  If unique=True, candidates are drawn without replacement, else unique=False with replacement.
360
361
 
361
362
  Args:
362
- true_classes (Tensor): A Tensor. The target classes with a Tensor shape of :math:`(batch_size, num_true)` .
363
+ true_classes (Tensor): A Tensor. The target classes with a Tensor shape of :math:`(batch\_size, num\_true)` .
363
364
  num_true (int): The number of target classes in each training example.
364
365
  num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
365
366
  of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
@@ -367,15 +368,17 @@ def uniform_candidate_sampler(true_classes,
367
368
  range_max (int): The number of possible classes, must be positive.
368
369
  seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
369
370
  the seed will be replaced with a randomly generated value. Default: ``0`` .
370
- remove_accidental_hits (bool): Whether accidental hit is removed. Default: ``False`` .
371
+ remove_accidental_hits (bool): Whether accidental hit is removed.
372
+ Accidental hit is when one of the true classes matches one of the sample classes.
373
+ Set ``True`` to remove which accidentally sampling the true class as sample class. Default: ``False`` .
371
374
 
372
375
  Returns:
373
376
  - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
374
- Shape: :math:`(num_sampled, )` .
377
+ Shape: :math:`(num\_sampled, )` .
375
378
  - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
376
- of true_classes. Shape: :math:`(batch_size, num_true)` .
379
+ of true_classes. Shape: :math:`(batch\_size, num\_true)` .
377
380
  - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
378
- each of sampled_candidates. Shape: :math:`(num_sampled, )` .
381
+ each of sampled_candidates. Shape: :math:`(num\_sampled, )` .
379
382
 
380
383
  Raises:
381
384
  TypeError: If neither `num_true` nor `num_sampled` is an int.
@@ -679,17 +682,15 @@ def normal(shape, mean, stddev, seed=None):
679
682
  Args:
680
683
  shape (tuple): The shape of random tensor to be generated.
681
684
  The format is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
682
- mean (Union[Tensor, int, float]): The mean μ distribution parameter, which specifies the location of the peak,
683
- with data type in [int8, int16, int32, int64, float16, float32].
684
- stddev (Union[Tensor, int, float]): The deviation σ distribution parameter. It should be greater than 0,
685
- with data type in [int8, int16, int32, int64, float16, float32].
685
+ mean (Union[Tensor, int, float]): The mean μ distribution parameter, which specifies the location of the peak.
686
+ stddev (Union[Tensor, int, float]): The deviation σ distribution parameter. It should be greater than 0.
686
687
  seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
687
688
  The value must be non-negative. Default: ``None`` , which will be treated as 0.
688
689
 
689
690
  Returns:
690
691
  Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
691
692
  of `mean` and `stddev`.
692
- The dtype is float32.
693
+ The dtype is [float32, float64].
693
694
 
694
695
  Supported Platforms:
695
696
  ``Ascend`` ``GPU`` ``CPU``
@@ -726,10 +727,6 @@ def normal(shape, mean, stddev, seed=None):
726
727
  mean = Tensor(mean)
727
728
  if not isinstance(stddev, Tensor):
728
729
  stddev = Tensor(stddev)
729
- mean_dtype = F.dtype(mean)
730
- stddev_dtype = F.dtype(stddev)
731
- const_utils.check_type_valid(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
732
- const_utils.check_type_valid(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
733
730
  seed1, seed2 = _get_seed(seed, "normal")
734
731
  stdnormal = P.StandardNormal(seed1, seed2)
735
732
  stdnormal = _set_prim_op_user_data(stdnormal, "random_cache", False)
@@ -840,26 +837,24 @@ def gamma(shape, alpha, beta, seed=None):
840
837
  >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
841
838
  >>> beta = Tensor(np.array([1.0, 2]), mindspore.float32)
842
839
  >>> output = ops.gamma(shape, alpha, beta, seed=5)
843
- >>> result = output.shape
844
840
  >>> print(output)
845
- [[[ 2.2132034 5.8855834]]
846
- [ 3.3981476 7.5805717]
847
- [[ 3.3981476 7.5805717]]
848
- [ 3.7190282 19.941492]
849
- [[ 2.9512358 2.5969937]]
850
- [ 3.786061 5.160872 ]]]
841
+ [[[ 2.2132034 5.8855834]
842
+ [ 3.8825176 8.6066265]]
843
+ [[ 3.3981476 7.5805717]
844
+ [ 3.7190282 19.941492 ]]
845
+ [[ 2.9512358 2.5969937]
846
+ [ 3.786061 5.160872 ]]]
851
847
  >>> # case 4: beta_shape is (2, 1), the output is different.
852
848
  >>> shape = (3, 1, 2)
853
849
  >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
854
850
  >>> beta = Tensor(np.array([[1.0], [2.0]]), mindspore.float32)
855
851
  >>> output = ops.gamma(shape, alpha, beta, seed=5)
856
- >>> result = output.shape
857
852
  >>> print(output)
858
- [[[ 5.6085486 7.8280783]]
859
- [ 15.97684 16.116285]
860
- [[ 1.8347423 1.713663]]
861
- [ 3.2434065 15.667398]
862
- [[ 4.2922077 7.3365674]]
853
+ [[[ 5.6085486 7.8280783]
854
+ [ 15.97684 16.116285]]
855
+ [[ 1.8347423 1.713663]
856
+ [ 3.2434065 15.667398]]
857
+ [[ 4.2922077 7.3365674]
863
858
  [ 5.3876944 13.159832 ]]]
864
859
  """
865
860
  seed1, seed2 = _get_seed(seed, "gamma")
@@ -1249,9 +1244,34 @@ def multinomial(input, num_samples, replacement=True, seed=None):
1249
1244
  Returns a tensor sampled from the multinomial probability distribution located in the corresponding
1250
1245
  row of the input tensor.
1251
1246
 
1247
+ The polynomial distribution is a probability distribution that generalizes the binomial distribution formula to
1248
+ multiple states. In the polynomial distribution, each event has a fixed probability, and the sum of these
1249
+ probabilities is 1. The purpose of the `mindspore.ops.multinomial` interface is to perform `num_samples` sampling
1250
+ on the input `input`, and the output tensor is the index of the input tensor for each sampling.
1251
+ The values in `input` represent the probability of selecting the corresponding index for each sampling.
1252
+
1253
+ Here is an extreme example for better understanding. Suppose we have an input probability tensor with
1254
+ values `Tensor([90 / 100, 10 / 100, 0], mindspore.float32)`, which means we can sample three indices,
1255
+ namely index 0, index 1, and index 2, with probabilities of 90%, 10%, and 0%, respectively. We perform n samplings,
1256
+ and the resulting sequence is the calculation result of the polynomial distribution, with a length equal to the
1257
+ number of samplings.
1258
+
1259
+ In case 1 of the sample code, we perform two non-replacement samplings (`replacement` is `False`).
1260
+ The calculation result is most likely `[0, 1]`, and less likely `[1, 0]`. Since the probability of selecting
1261
+ index 0 is 90% for each sampling, the first result is most likely to be index 0. Since the probability of selecting
1262
+ index 2 is 0, index 2 cannot appear in the sampling result. Therefore, the second result must be index 1,
1263
+ and the resulting sequence is `[0, 1]`.
1264
+
1265
+ In case 2 of the sample code, we perform 10 replacement samplings (`replacement` is `True`).
1266
+ As expected, about 90% of the sampling results are index 0.
1267
+
1268
+ In case 3 of the sample code, we extend the input to 2 dimensions, and the sampling results
1269
+ in each dimension also match our sampling expectations.
1270
+
1252
1271
  Note:
1253
1272
  The rows of input do not need to sum to one (in which case we use the values as weights),
1254
- but must be non-negative, finite and have a non-zero sum.
1273
+ but must be non-negative, finite and have a non-zero sum. When using values as weights, it can be understood as
1274
+ normalizing the input along the last dimension.
1255
1275
 
1256
1276
  Args:
1257
1277
  input (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
@@ -1278,27 +1298,35 @@ def multinomial(input, num_samples, replacement=True, seed=None):
1278
1298
  >>> from mindspore import Tensor, ops
1279
1299
  >>> from mindspore import dtype as mstype
1280
1300
  >>> # case 1: The output is random, and the length of the output is the same as num_sample.
1281
- >>> input = Tensor([0, 9, 4, 0], mindspore.float32)
1282
- >>> output = ops.multinomial(input, 2)
1283
- >>> # print(output)
1284
- >>> # [1 2] or [2 1]
1285
- >>> # the case where the result is [2 1] in multiple times.
1286
- >>> # This is because the value corresponding to the index 1 is larger than the value of the index 2.
1287
- >>> print(len(output))
1301
+ >>> # replacement is False.
1302
+ >>> input1 = Tensor([90 / 100, 10 / 100, 0], mindspore.float32)
1303
+ >>> input2 = Tensor([90, 10, 0], mindspore.float32)
1304
+ >>> # input1 and input2 have the same meaning.
1305
+ >>> output1 = ops.multinomial(input1, 2, replacement=False)
1306
+ >>> output2 = ops.multinomial(input2, 2, replacement=False)
1307
+ >>> # print(output1)
1308
+ >>> # [0 1]
1309
+ >>> # print(output2)
1310
+ >>> # [0 1]
1311
+ >>> print(len(output1))
1312
+ 2
1313
+ >>> print(len(output2))
1288
1314
  2
1289
1315
  >>> # case 2: The output is random, and the length of the output is the same as num_sample.
1290
- >>> # replacement is False(Default).
1291
- >>> # If the extracted value is 0, the index value of 1 will be returned.
1292
- >>> input = Tensor([0, 9, 4, 0], mstype.float32)
1293
- >>> output = ops.multinomial(input, 4)
1294
- >>> print(output)
1295
- [1 1 2 1]
1296
- >>> # case 3: The output is random, num_sample == x_length = 4, and replacement is True,
1297
- >>> # Can extract the same elements。
1298
- >>> input = Tensor([0, 9, 4, 0], mstype.float32)
1299
- >>> output = ops.multinomial(input, 4, True)
1300
- >>> print(output)
1301
- [1 1 2 2]
1316
+ >>> # replacement is True.
1317
+ >>> output3 = ops.multinomial(input1, 10)
1318
+ >>> # print(output3)
1319
+ >>> # [0 0 1 0 0 0 0 0 0 0]
1320
+ >>> print(len(output3))
1321
+ 10
1322
+ >>> # case 3: The output is random, and the length of the output is the same as num_sample.
1323
+ >>> # replacement is True.
1324
+ >>> # rank is 2
1325
+ >>> input4 = Tensor([[90, 10, 0], [10, 90, 0]], mstype.float32)
1326
+ >>> output4 = ops.multinomial(input4, 10)
1327
+ >>> # print(output4)
1328
+ >>> # [[0 0 0 0 0 0 0 0 1 0]
1329
+ >>> # [1 1 1 1 1 0 1 1 1 1]]
1302
1330
  """
1303
1331
  shape = _get_cache_prim(P.Shape)()
1304
1332
  reshape = _get_cache_prim(P.Reshape)()
@@ -1318,7 +1346,9 @@ def multinomial(input, num_samples, replacement=True, seed=None):
1318
1346
  n_dist = 1
1319
1347
  if len(shape(input)) > 1:
1320
1348
  n_dist = shape(input)[-2]
1321
- random_uniform = _get_cache_prim(P.UniformReal)(seed1, seed2)((n_dist * shape(input)[-1],))
1349
+ random_uniform_real = P.UniformReal(seed1, seed2)
1350
+ random_cache_op = _set_prim_op_user_data(random_uniform_real, "random_cache", False)
1351
+ random_uniform = random_cache_op((n_dist * shape(input)[-1],))
1322
1352
  if n_dist != 1:
1323
1353
  random_uniform = reshape(random_uniform, (n_dist, shape(input)[-1]))
1324
1354
  real_div = _get_cache_prim(P.RealDiv)()
@@ -1336,18 +1366,18 @@ def multinomial(input, num_samples, replacement=True, seed=None):
1336
1366
  def _check_shape(input_shape):
1337
1367
  """Check 'shape' value."""
1338
1368
  if not isinstance(input_shape, tuple):
1339
- const_utils.raise_type_error("Type of 'shape' must be tuple, but got: {}".format(type(input_shape)))
1369
+ const_utils.raise_type_error(f"Type of 'shape' must be tuple, but got: {type(input_shape)}")
1340
1370
  for item in input_shape:
1341
1371
  if not isinstance(item, int):
1342
- const_utils.raise_type_error("Elements of 'shape' must be int, but got: {}".format(type(item)))
1372
+ const_utils.raise_type_error(f"Elements of 'shape' must be int, but got: {type(item)}")
1343
1373
  if item < 1:
1344
- const_utils.raise_value_error("Elements of 'shape' must be positive int, but got: {}".format(item))
1374
+ const_utils.raise_value_error(f"Elements of 'shape' must be positive int, but got: {item}")
1345
1375
  return True
1346
1376
 
1347
1377
 
1348
1378
  def _check_param(op_name, param_name, param_value):
1349
1379
  """Check type of param_value is Tensor, int, or float."""
1350
- if not isinstance(param_value, (Tensor, int, float)):
1380
+ if not isinstance(param_value, (Tensor, int, float, np.ndarray)):
1351
1381
  const_utils.raise_type_error("For '{}', the type of '{}' must be Tensor, int, or float, "
1352
1382
  "but got: {}".format(op_name, param_name, type(param_value)))
1353
1383
  return True
@@ -423,7 +423,7 @@ def csr_to_dense(csr_tensor: CSRTensor) -> Tensor:
423
423
 
424
424
  valid_indices_dtype = [mstype.int32, mstype.int64]
425
425
  if row_pointers.dtype in valid_indices_dtype and col_indices.dtype in valid_indices_dtype:
426
- if row_pointers.dtype == mstype.int64 or col_indices.dtype == mstype.int64:
426
+ if mstype.int64 in (row_pointers.dtype, col_indices.dtype):
427
427
  return csr_sparse_matrix_to_dense(dense_shape.astype(mstype.int64), batch_pointers.astype(mstype.int64),
428
428
  row_pointers.astype(mstype.int64), col_indices.astype(mstype.int64),
429
429
  values)
@@ -383,7 +383,7 @@ def coo_relu(x: COOTensor) -> COOTensor:
383
383
  Args:
384
384
  x (COOTensor): Input COOTensor with shape :math:`(N, *)`, where :math:`*`
385
385
  means any number of additional dimensions. Its dtype is
386
- `number <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.html#mindspore.dtype>`_.
386
+ `number <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.html#mindspore.dtype>`_.
387
387
 
388
388
  Returns:
389
389
  COOTensor, has the same shape and dtype as the `x`.
@@ -1980,7 +1980,7 @@ def coo_acosh(x: COOTensor) -> COOTensor:
1980
1980
 
1981
1981
  .. math::
1982
1982
 
1983
- out_i = \cosh^{-1}(input_i)
1983
+ y_i = \cosh^{-1}(x_i)
1984
1984
 
1985
1985
  .. warning::
1986
1986
  Given an input COOTensor x, the function computes inverse hyperbolic cosine of every element.
@@ -2414,14 +2414,16 @@ def coo_sin(x: COOTensor) -> COOTensor:
2414
2414
  return COOTensor(x.indices, math_func.sin(x.values), x.shape)
2415
2415
 
2416
2416
 
2417
- __all__ = ["csr_cos", "csr_tan", "csr_exp", "csr_inv", "csr_relu", "csr_expm1", "csr_isfinite",
2418
- "csr_asin", "csr_sqrt", "csr_log", "csr_isnan", "csr_acos", "csr_floor", "csr_atan",
2419
- "csr_square", "csr_relu6", "csr_sinh", "csr_ceil", "csr_cosh", "csr_softsign",
2420
- "csr_log1p", "csr_round", "csr_tanh", "csr_asinh", "csr_neg", "csr_acosh", "csr_isinf",
2421
- "csr_atanh", "csr_sigmoid", "csr_abs", "csr_sin", "coo_cos", "coo_tan", "coo_exp",
2422
- "coo_inv", "coo_relu", "coo_expm1", "coo_isfinite", "coo_asin", "coo_sqrt", "coo_log",
2423
- "coo_isnan", "coo_acos", "coo_floor", "coo_atan", "coo_square", "coo_relu6", "coo_sinh",
2424
- "coo_ceil", "coo_cosh", "coo_softsign", "coo_log1p", "coo_round", "coo_tanh",
2425
- "coo_asinh", "coo_neg", "coo_acosh", "coo_isinf", "coo_atanh", "coo_sigmoid", "coo_abs",
2426
- "coo_sin"]
2417
+ __all__ = [
2418
+ "csr_cos", "csr_tan", "csr_exp", "csr_inv", "csr_relu", "csr_expm1", "csr_isfinite",
2419
+ "csr_asin", "csr_sqrt", "csr_log", "csr_isnan", "csr_acos", "csr_floor", "csr_atan",
2420
+ "csr_square", "csr_relu6", "csr_sinh", "csr_ceil", "csr_cosh", "csr_softsign",
2421
+ "csr_log1p", "csr_round", "csr_tanh", "csr_asinh", "csr_neg", "csr_acosh", "csr_isinf",
2422
+ "csr_atanh", "csr_sigmoid", "csr_abs", "csr_sin", "coo_cos", "coo_tan", "coo_exp",
2423
+ "coo_inv", "coo_relu", "coo_expm1", "coo_isfinite", "coo_asin", "coo_sqrt", "coo_log",
2424
+ "coo_isnan", "coo_acos", "coo_floor", "coo_atan", "coo_square", "coo_relu6", "coo_sinh",
2425
+ "coo_ceil", "coo_cosh", "coo_softsign", "coo_log1p", "coo_round", "coo_tanh",
2426
+ "coo_asinh", "coo_neg", "coo_acosh", "coo_isinf", "coo_atanh", "coo_sigmoid", "coo_abs",
2427
+ "coo_sin"
2428
+ ]
2427
2429
  __all__.sort()
@@ -26,9 +26,11 @@ def vmap(fn, in_axes=0, out_axes=0):
26
26
 
27
27
  Vmap is pioneered by Jax and it removes the restriction of batch dimension on the operator, and provides a
28
28
  more convenient and unified operator expression. Moreover, it allows users to composite with other functional
29
- modules such as :func:`mindspore.grad`, to improve the development efficiency. In addition, the vectorizing
30
- map does not execute loops outside the function, but sinks loops into the primitive operations of the function
31
- for better performance. When combined with `Graph Kernel Fusion`, operational efficiency would be further improved.
29
+ modules such as :func:`mindspore.grad`, to improve the development efficiency, please refer to the
30
+ `Automatic Vectorization (Vmap) <https://www.mindspore.cn/tutorials/experts/en/r2.2/vmap/vmap.html>`_ tutorial
31
+ for more detail. In addition, the vectorizing map does not execute loops outside the function, but sinks loops
32
+ into the primitive operations of the function for better performance. When combined with `Graph Kernel Fusion`,
33
+ operational efficiency would be further improved.
32
34
 
33
35
  .. warning::
34
36
  This is an experimental API that is subject to change or deletion.
@@ -36,8 +38,7 @@ def vmap(fn, in_axes=0, out_axes=0):
36
38
  Note:
37
39
  1. The power of vmap comes from the implementation of VmapRules of primitives. Although we have designed a
38
40
  generalized rule for user custom operators, we can not guarantee that it works well for all operators,
39
- please be aware the risk of use. If you want to achieve a better performance, please refer to the tutorial to
40
- implement the specific VmapRule for the custom operator, which won't take too much time.
41
+ unknown exceptions may occur, please be aware the risk of use.
41
42
  2. When calling the random number generation methods within the scope of vmap, the same random number is
42
43
  generated among vector functions each time. If you expect each vector branch to use different random numbers,
43
44
  you need to generate batch random numbers externally in advance and then transfer them to vmap.
@@ -20,6 +20,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
20
20
  from mindspore.ops import _constants
21
21
  from mindspore.ops.function import *
22
22
  from mindspore.ops.function.array_func import narrow, flatten
23
+ from mindspore.ops.function.math_func import all
23
24
  from mindspore.ops import operations as P
24
25
  from mindspore.ops.operations import array_ops
25
26
  from mindspore.ops.primitive import Primitive
@@ -122,15 +123,16 @@ reduced_shape = Primitive("reduced_shape")
122
123
  # shape_mul:input must be shape multiply elements in tuple(shape)
123
124
  shape_mul = _sequence_ops.shape_mul()
124
125
 
126
+ tensor_operator_registry.register('tuple_to_tensor', _sequence_ops.TupleToTensor)
125
127
  tensor_operator_registry.register('add', P.Add)
126
128
  tensor_operator_registry.register('addr', addr)
127
129
  tensor_operator_registry.register('addcdiv', P.Addcdiv)
128
130
  tensor_operator_registry.register('addcmul', P.Addcmul)
129
- tensor_operator_registry.register('all', P.ReduceAll)
131
+ tensor_operator_registry.register('all', all)
130
132
  tensor_operator_registry.register('angle', angle)
131
133
  tensor_operator_registry.register('any', P.ReduceAny)
132
134
  tensor_operator_registry.register('atan2', atan2)
133
- tensor_operator_registry.register('abs', P.Abs)
135
+ tensor_operator_registry.register('abs', abs)
134
136
  tensor_operator_registry.register('baddbmm', baddbmm)
135
137
  tensor_operator_registry.register('geqrf', geqrf)
136
138
  tensor_operator_registry.register('histc', histc)
@@ -142,6 +144,7 @@ tensor_operator_registry.register('slogdet', slogdet)
142
144
  tensor_operator_registry.register('trace', trace)
143
145
  tensor_operator_registry.register('tril', tril)
144
146
  tensor_operator_registry.register('chunk', chunk)
147
+ tensor_operator_registry.register('count_nonzero', count_nonzero)
145
148
  tensor_operator_registry.register('sqrt', sqrt)
146
149
  tensor_operator_registry.register('square', square)
147
150
  tensor_operator_registry.register('sub', sub)
@@ -163,10 +166,10 @@ tensor_operator_registry.register('negative', neg)
163
166
  tensor_operator_registry.register('amin', amin)
164
167
  tensor_operator_registry.register('amax', amax)
165
168
  tensor_operator_registry.register('aminmax', aminmax)
166
- tensor_operator_registry.register('mean', P.ReduceMean)
169
+ tensor_operator_registry.register('mean', mean)
167
170
  tensor_operator_registry.register('prod', prod)
168
171
  tensor_operator_registry.register('round', P.Round)
169
- tensor_operator_registry.register('reshape', P.Reshape)
172
+ tensor_operator_registry.register('reshape', reshape)
170
173
  tensor_operator_registry.register('reverse', P.ReverseV2)
171
174
  tensor_operator_registry.register('reverse_sequence', P.ReverseSequence)
172
175
  tensor_operator_registry.register('xlogy', P.Xlogy)
@@ -176,8 +179,8 @@ tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
176
179
  tensor_operator_registry.register('matmul', matmul)
177
180
  tensor_operator_registry.register('inner', inner)
178
181
  tensor_operator_registry.register('xdivy', P.Xdivy)
179
- tensor_operator_registry.register('argmax', P.Argmax)
180
- tensor_operator_registry.register('argmin', P.Argmin)
182
+ tensor_operator_registry.register('argmax', argmax)
183
+ tensor_operator_registry.register('argmin', argmin)
181
184
  tensor_operator_registry.register('cumsum', P.CumSum)
182
185
  tensor_operator_registry.register('cummin', cummin)
183
186
  tensor_operator_registry.register('cummax', cummax)
@@ -216,10 +219,11 @@ tensor_operator_registry.register('logdet', logdet)
216
219
  tensor_operator_registry.register('log_matrix_determinant', log_matrix_determinant)
217
220
  tensor_operator_registry.register('matrix_determinant', matrix_determinant)
218
221
  tensor_operator_registry.register('ceil', P.Ceil)
219
- tensor_operator_registry.register('fill', P.Fill)
220
- tensor_operator_registry.register('tile', P.Tile)
222
+ tensor_operator_registry.register('fillv2', P.FillV2)
223
+ tensor_operator_registry.register('tile', tile)
221
224
  tensor_operator_registry.register('logit', logit)
222
- tensor_operator_registry.register('sum', P.ReduceSum)
225
+ tensor_operator_registry.register('sum', sum)
226
+ tensor_operator_registry.register('reducesum', P.ReduceSum)
223
227
  tensor_operator_registry.register('split', split)
224
228
  tensor_operator_registry.register('tensor_split', tensor_split)
225
229
  tensor_operator_registry.register('vsplit', vsplit)
@@ -264,7 +268,7 @@ tensor_operator_registry.register('standard_normal', P.StandardNormal)
264
268
  tensor_operator_registry.register('sigmoid', P.Sigmoid)
265
269
  tensor_operator_registry.register('median', Median)
266
270
  tensor_operator_registry.register('tanh', tanh)
267
- tensor_operator_registry.register('exp', P.Exp)
271
+ tensor_operator_registry.register('exp', exp)
268
272
  tensor_operator_registry.register('addbmm', addbmm)
269
273
  tensor_operator_registry.register('addmm', addmm)
270
274
  tensor_operator_registry.register('addmv', addmv)
@@ -284,6 +288,7 @@ tensor_operator_registry.register('ldexp', ldexp)
284
288
  tensor_operator_registry.register('clamp', clamp)
285
289
  tensor_operator_registry.register('fold', fold)
286
290
  tensor_operator_registry.register('unfold', unfold)
291
+ tensor_operator_registry.register('diagonal', diagonal)
287
292
  tensor_operator_registry.register('diagonal_scatter', diagonal_scatter)
288
293
  tensor_operator_registry.register('index_add', index_add)
289
294
  tensor_operator_registry.register('greater', greater)