mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -19,22 +19,24 @@ from collections.abc import Iterable
19
19
  import numpy as np
20
20
 
21
21
  from mindspore.common import Tensor
22
+ from mindspore.common._stub_tensor import StubTensor
22
23
  from mindspore.ops import composite as C
23
24
  from mindspore.ops.operations.array_ops import Cast
24
25
  from mindspore.ops.operations._scalar_ops import bit_or, bit_and
26
+ from mindspore.ops.operations.comm_ops import ReduceOp
25
27
  from mindspore.ops import signature as sig
26
28
  from mindspore.ops.operations.math_ops import _infer_shape_reduce
27
- from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
28
- from mindspore import context
29
+ from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive,\
30
+ _run_op, _check_contains_variable
29
31
  from mindspore._c_expression import Tensor as Tensor_
30
32
  from mindspore._c_expression import typing
31
33
  from mindspore import _checkparam as validator
32
34
  from mindspore.common import dtype as mstype
33
35
  from mindspore.common.parameter import Parameter
34
- from mindspore.communication.management import GlobalComm
36
+ from mindspore.communication.management import GlobalComm, get_rank
35
37
  from mindspore.common.api import _pynative_executor
36
38
  from mindspore.common._register_for_adapter import ms_adapter_registry
37
-
39
+ from mindspore import ops
38
40
 
39
41
  # Bit operation
40
42
  bit_and = bit_and()
@@ -73,12 +75,11 @@ class ExtractImagePatches(Primitive):
73
75
  - valid: Means that the taken patch area must be completely covered in the original image.
74
76
 
75
77
  Inputs:
76
- - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
77
- data type is number.
78
+ - **input_x** (Tensor) - A 4-D tensor whose shape is :math:`(in\_batch, in\_depth, in\_row, in\_col)`.
78
79
 
79
80
  Outputs:
80
- Tensor, a 4-D tensor whose data type is same as 'input_x',
81
- and the shape is [out_batch, out_depth, out_row, out_col], Where the out_batch is the same as the in_batch
81
+ Tensor, a 4-D tensor whose data type is same as 'input_x', and the shape
82
+ is :math:`(out\_batch, out\_depth, out\_row, out\_col)`, where the out_batch is the same as the in_batch
82
83
  and
83
84
 
84
85
  .. math::
@@ -121,7 +122,6 @@ class ExtractImagePatches(Primitive):
121
122
  validator.check_value_type('padding', padding, [str], self.name)
122
123
  self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
123
124
  self.add_prim_attr("padding", self.padding)
124
- self.is_ge = context.get_context("enable_ge")
125
125
 
126
126
 
127
127
  class Quant(PrimitiveWithInfer):
@@ -144,7 +144,7 @@ class Quant(PrimitiveWithInfer):
144
144
  Args:
145
145
  scale (float) : Specifies the scaling ratio.
146
146
  offset (float): Specifies the offset.
147
- sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False.
147
+ sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
148
148
  round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
149
149
  Default: "Round".
150
150
 
@@ -172,7 +172,7 @@ class Quant(PrimitiveWithInfer):
172
172
  return x_shape
173
173
 
174
174
  def infer_dtype(self, x_type):
175
- validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
175
+ validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
176
176
  validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
177
177
  return mstype.int8
178
178
 
@@ -254,8 +254,8 @@ class Dequant(PrimitiveWithInfer):
254
254
  This operation only support Ascend 310 inference environment.
255
255
 
256
256
  Args:
257
- sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False.
258
- relu_flag (bool): Specifies whether to perform ReLU. Default: False.
257
+ sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
258
+ relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.
259
259
 
260
260
  Inputs:
261
261
  - **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
@@ -281,7 +281,7 @@ class Dequant(PrimitiveWithInfer):
281
281
  return x_shape
282
282
 
283
283
  def infer_dtype(self, x_type, deq_scale_type):
284
- validator.check_subclass("x", x_type, mstype.tensor, self.name)
284
+ validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
285
285
  validator.check_type_name("x", x_type, [mstype.int32], self.name)
286
286
  validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
287
287
  return mstype.float16
@@ -502,6 +502,109 @@ class Receive(PrimitiveWithInfer):
502
502
  return self.get_attr_dict()['dtype']
503
503
 
504
504
 
505
+ class Reduce(PrimitiveWithInfer):
506
+ """
507
+ Reduces tensor across the processes in the specified communication group.
508
+
509
+ Note:
510
+ Only process with destination rank receives the reduced output.
511
+ Other processes only get a tensor with shape [1], which has no mathematical meaning.
512
+
513
+ Args:
514
+ dest_rank (int): Specifies the rank of the process that receives the reduced output.
515
+ op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
516
+ On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
517
+ group (str, optional): The communication group to work on.
518
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
519
+
520
+ Inputs:
521
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
522
+
523
+ Examples:
524
+ >>> import mindspore.ops as ops
525
+ >>> import mindspore.nn as nn
526
+ >>> from mindspore.communication import init
527
+ >>> from mindspore import Tensor
528
+ >>> import numpy as np
529
+ >>> # Launch 4 processes.
530
+ >>> init()
531
+ >>> class ReduceNet(nn.Cell):
532
+ >>> def __init__(self):
533
+ >>> super(ReduceNet, self).__init__()
534
+ >>> self.reduce = ops.Reduce(dest_rank=1)
535
+ >>>
536
+ >>> def construct(self, x):
537
+ >>> out = self.reduce(x)
538
+ >>> return out
539
+ >>> input = Tensor(np.ones([2, 8]).astype(np.float32))
540
+ >>> net = ReduceNet()
541
+ >>> output = net(input)
542
+ >>> print(output)
543
+ Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
544
+ [4. 4. 4. 4. 4. 4. 4. 4.]],
545
+ Other processes: [0.].
546
+ """
547
+
548
+ @prim_attr_register
549
+ def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
550
+ self.dest_rank = dest_rank
551
+ self.op = op
552
+ self.group = group
553
+
554
+ def infer_shape(self, x_shape):
555
+ # The process with dest_rank returns the reduced output.
556
+ # Other processes only get a tensor with shape [1], which has no mathematical meaning.
557
+ if self.dest_rank == get_rank():
558
+ return x_shape
559
+ return [1]
560
+
561
+ def infer_dtype(self, x_dtype):
562
+ return x_dtype
563
+
564
+
565
+ class Barrier(PrimitiveWithInfer):
566
+ """
567
+ Synchronizes all processes in the specified group.
568
+
569
+ Note:
570
+ After calling this collective operator,
571
+ this process will be blocked until all other processes in the group call this operator.
572
+
573
+ Args:
574
+ group (str, optional): The communication group to work on.
575
+ Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
576
+
577
+ Examples:
578
+ >>> import mindspore.ops as ops
579
+ >>> import mindspore.nn as nn
580
+ >>> from mindspore.communication import init
581
+ >>> from mindspore import Tensor
582
+ >>> import numpy as np
583
+ >>> # Launch 4 processes.
584
+ >>> init()
585
+ >>> class BarrierNet(nn.Cell):
586
+ >>> def __init__(self):
587
+ >>> super(BarrierNet, self).__init__()
588
+ >>> self.barrier = ops.Barrier()
589
+ >>>
590
+ >>> def construct(self):
591
+ >>> self.barrier()
592
+ >>> net = BarrierNet()
593
+ >>> net()
594
+ """
595
+
596
+ @prim_attr_register
597
+ def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
598
+ self.group = group
599
+ self.add_prim_attr("side_effect_mem", True)
600
+
601
+ def infer_shape(self):
602
+ return [1]
603
+
604
+ def infer_dtype(self):
605
+ return mstype.float32
606
+
607
+
505
608
  class MatrixSetDiag(PrimitiveWithInfer):
506
609
  r"""
507
610
  Modifies the batched diagonal part of a batched tensor.
@@ -604,9 +707,9 @@ class ConfusionMulGrad(PrimitiveWithInfer):
604
707
  return outshape0, outshape1
605
708
 
606
709
  def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
607
- validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name)
608
- validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
609
- validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
710
+ validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor_type, self.name)
711
+ validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor_type, self.name)
712
+ validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor_type, self.name)
610
713
  return input0_dtype, input1_dtype
611
714
 
612
715
 
@@ -619,7 +722,7 @@ class ConvertToDynamic(PrimitiveWithCheck):
619
722
 
620
723
  Args:
621
724
  is_dynamic_rank (bool): If true, convert to dynamic rank.
622
- If false, convert to dynamic shape. Default: False.
725
+ If false, convert to dynamic shape. Default: ``False``.
623
726
 
624
727
  Inputs:
625
728
  - **input** (Tensor) - The tensor used for testing.
@@ -664,7 +767,7 @@ class ConvertToDynamic(PrimitiveWithCheck):
664
767
  validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)
665
768
 
666
769
  def check_dtype(self, input_dtype):
667
- validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
770
+ validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
668
771
 
669
772
 
670
773
  class GpuConvertToDynamicShape(PrimitiveWithCheck):
@@ -714,7 +817,7 @@ class GpuConvertToDynamicShape(PrimitiveWithCheck):
714
817
  validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)
715
818
 
716
819
  def check_dtype(self, input_dtype):
717
- validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
820
+ validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
718
821
 
719
822
 
720
823
  class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
@@ -766,7 +869,7 @@ class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
766
869
 
767
870
  def infer_type(self, input_dtype):
768
871
  """Infer the dtype of input for ErrorOnDynamicShapeInput."""
769
- validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
872
+ validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
770
873
  return input_dtype
771
874
 
772
875
  def infer_value(self, input_tensor):
@@ -816,7 +919,7 @@ class SequenceMask(PrimitiveWithCheck):
816
919
  validator.check("maxlen_shape", len(maxlen_shape), "", 0, validator.EQ, self.name)
817
920
 
818
921
  def check_dtype(self, lengths_dtype, maxlen_dtype):
819
- validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
922
+ validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
820
923
  validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
821
924
 
822
925
 
@@ -1170,8 +1273,8 @@ class DynamicStitch(PrimitiveWithCheck):
1170
1273
  return out_shape
1171
1274
 
1172
1275
  def check_dtype(self, indices_type, data_type):
1173
- validator.check_subclass("indices[0]", indices_type[0], mstype.tensor, self.name)
1174
- validator.check_subclass("data[0]", data_type[0], mstype.tensor, self.name)
1276
+ validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
1277
+ validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
1175
1278
  indices_num = len(indices_type)
1176
1279
  for i in range(0, indices_num):
1177
1280
  validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
@@ -1418,6 +1521,7 @@ class DecodeImage(PrimitiveWithInfer):
1418
1521
 
1419
1522
  Examples:
1420
1523
  """
1524
+
1421
1525
  @prim_attr_register
1422
1526
  def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
1423
1527
  _op_max_size=[8000000]):
@@ -1467,7 +1571,7 @@ class DynamicBroadcastTo(Primitive):
1467
1571
  Inputs:
1468
1572
  - **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
1469
1573
  float16, float32, int32, int8, uint8.
1470
- The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
1574
+ The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
1471
1575
  - **shape** (Tensor): The target shape to broadcast.
1472
1576
 
1473
1577
  Outputs:
@@ -1495,6 +1599,16 @@ class Cummin(Primitive):
1495
1599
 
1496
1600
  Refer to :func:`mindspore.ops.cummin` for more detail.
1497
1601
 
1602
+ Args:
1603
+ axis (int): The axis to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)).
1604
+
1605
+ Inputs:
1606
+ - **input** (Tensor) - The input tensor.
1607
+
1608
+ Outputs:
1609
+ A tuple of 2 Tensors (values, indices), containing the cumulative minimum of elements and the index.
1610
+ The shape of each output tensor is the same as input `input`.
1611
+
1498
1612
  Supported Platforms:
1499
1613
  ``Ascend`` ``GPU`` ``CPU``
1500
1614
 
@@ -1509,6 +1623,7 @@ class Cummin(Primitive):
1509
1623
  >>> print(output[1])
1510
1624
  [0 1 1 1 4 4]
1511
1625
  """
1626
+
1512
1627
  @prim_attr_register
1513
1628
  def __init__(self, axis):
1514
1629
  """Initialize Cummin"""
@@ -1528,7 +1643,7 @@ class DynamicResizeNearestNeighbor(Primitive):
1528
1643
 
1529
1644
  Args:
1530
1645
  align_corners (bool): Whether the centers of the 4 corner pixels of the input
1531
- and output tensors are aligned. Default: False.
1646
+ and output tensors are aligned. Default: ``False``.
1532
1647
 
1533
1648
  Inputs:
1534
1649
  - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
@@ -1613,7 +1728,7 @@ class PsROIPooling(PrimitiveWithInfer):
1613
1728
  return output_shape, output_map_shape
1614
1729
 
1615
1730
  def infer_dtype(self, inputs_type, rois_type):
1616
- map_type = mstype.tensor_type(mstype.int32)
1731
+ map_type = mstype.TensorType(mstype.int32)
1617
1732
  return inputs_type, map_type
1618
1733
 
1619
1734
 
@@ -1671,8 +1786,10 @@ class PartitionedCall(PrimitiveWithInfer):
1671
1786
 
1672
1787
  Examples:
1673
1788
  """
1789
+
1674
1790
  @prim_attr_register
1675
1791
  def __init__(self, graph, executor_type=""):
1792
+ super(PartitionedCall, self).__init__(self.__class__.__name__)
1676
1793
  self.add_prim_attr("executor_type", executor_type)
1677
1794
  self.graph = graph
1678
1795
 
@@ -1744,9 +1861,6 @@ class CellBackwardHook(PrimitiveWithInfer):
1744
1861
  def __call__(self, args):
1745
1862
  if not isinstance(args, tuple):
1746
1863
  args = (args,)
1747
- for arg in args:
1748
- if isinstance(arg, Parameter) and arg.has_init:
1749
- arg.init_data()
1750
1864
  return _run_op(self, self.name, args)
1751
1865
 
1752
1866
  def infer_shape(self, *inputs_shape):
@@ -1832,16 +1946,32 @@ class Format(PrimitiveWithInfer):
1832
1946
  def __init__(self):
1833
1947
  self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
1834
1948
 
1949
+
1835
1950
  def __infer__(self, str_, *var):
1836
- str_value = str_["value"]
1951
+ def check_variable(str_, var):
1952
+ if _check_contains_variable(str_['dtype'], str_['value']):
1953
+ return True
1954
+
1955
+ for item in var:
1956
+ if _check_contains_variable(item['dtype'], item['value']):
1957
+ return True
1958
+ return False
1959
+
1960
+
1961
+ if check_variable(str_, var):
1962
+ return {'dtype': mstype.string, 'shape': [], 'value': None}
1963
+
1964
+
1965
+ str_value = str_['value']
1966
+ kwargs = dict()
1837
1967
  var_value = list()
1838
- if str_value is None and str_["dtype"] is not None:
1839
- raise ValueError("str.format not support to input a variable.")
1968
+
1840
1969
  for item in var:
1841
- if item["value"] is None and item["dtype"] is not None:
1842
- raise ValueError("str.format not support to input a variable.")
1970
+ if isinstance(item["dtype"], typing.Keyword):
1971
+ kwargs.update(item["value"])
1843
1972
  var_value.append(item["value"])
1844
- value = str_value.format(*var_value)
1973
+
1974
+ value = str_value.format(*var_value, **kwargs)
1845
1975
  return {'dtype': mstype.string, 'shape': [], 'value': value}
1846
1976
 
1847
1977
 
@@ -1982,7 +2112,7 @@ class ClipByNorm(PrimitiveWithInfer):
1982
2112
 
1983
2113
  Args:
1984
2114
  axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specific dimension.
1985
- Default: None, all dimensions to calculate.
2115
+ Default: ``None``, all dimensions to calculate.
1986
2116
 
1987
2117
  Inputs:
1988
2118
  - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
@@ -2060,8 +2190,8 @@ class TopTypeof(Primitive):
2060
2190
  'slice': mstype.Slice(),
2061
2191
  'list': mstype.List(),
2062
2192
  'tuple': mstype.Tuple(),
2063
- 'Tensor': mstype.tensor,
2064
- 'NoneType': mstype.none_type(),
2193
+ 'Tensor': mstype.tensor_type,
2194
+ 'NoneType': mstype.NoneType(),
2065
2195
  'int': mstype.Int(),
2066
2196
  'bool': mstype.Bool(),
2067
2197
  'ellipsis': mstype.Ellipsis_(),
@@ -2098,7 +2228,7 @@ class MixedPrecisionCast(Primitive):
2098
2228
  Examples:
2099
2229
  >>> import numpy as np
2100
2230
  >>> from mindspore import Tensor
2101
- >>> from mindspore.common import dtype as mstype
2231
+ >>> from mindspore import dtype as mstype
2102
2232
  >>> from mindspore.ops.operations import _inner_ops as inner
2103
2233
  >>> x = Tensor(np.ones([2, 3], dtype=np.float32))
2104
2234
  >>> out = inner.MixedPrecisionCast(mstype.float16, x)
@@ -2175,13 +2305,22 @@ class CheckBprop(PrimitiveWithInfer):
2175
2305
  raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
2176
2306
  f"the number of input arguments except 'out' and 'dout', "
2177
2307
  f"which is:{len(yshapes)} but got {len(xshapes)}.")
2178
- checking_range = len(yshapes)
2179
- for i in range(checking_range):
2180
- xshape = xshapes[i]
2181
- yshape = yshapes[i]
2308
+
2309
+ def shape_equal(shape1, shape2):
2310
+ if len(shape1) != len(shape2):
2311
+ return False
2312
+ for shape_axis1, shape_axis2 in zip(shape1, shape2):
2313
+ if shape_axis1 == -1 or shape_axis2 == -1:
2314
+ continue
2315
+ if shape_axis1 != shape_axis2:
2316
+ return False
2317
+ return True
2318
+
2319
+ for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
2182
2320
  if not xshape or not yshape:
2183
2321
  continue
2184
- if xshape != yshape:
2322
+
2323
+ if not shape_equal(xshape, yshape):
2185
2324
  raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
2186
2325
  f"should have the same shape as the {i}th argument, "
2187
2326
  f"which is:{yshape}, but got: {xshape}.")
@@ -2200,18 +2339,19 @@ class CheckBprop(PrimitiveWithInfer):
2200
2339
  for i in range(checking_range):
2201
2340
  xdtype = xdtypes[i]
2202
2341
  ydtype = ydtypes[i]
2203
- if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
2342
+ if isinstance(xdtype, mstype.AnythingType) or isinstance(ydtype, mstype.AnythingType):
2204
2343
  continue
2205
- if isinstance(ydtype, mstype.function_type):
2206
- if not isinstance(xdtype, mstype.env_type_type):
2344
+ if isinstance(ydtype, mstype.FunctionType):
2345
+ if not isinstance(xdtype, mstype.EnvType):
2207
2346
  raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
2208
- f"should be {mstype.env_type_type}, but got {xdtype}.")
2347
+ f"should be {mstype.EnvType}, but got {xdtype}.")
2209
2348
  if xdtype != ydtype:
2210
2349
  raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
2211
2350
  f"should have the same dtype as the {i}th argument, "
2212
2351
  f"which is:{ydtype}, but got: {xdtype}.")
2213
2352
  return xdtypes
2214
2353
 
2354
+
2215
2355
  check_bprop = CheckBprop()
2216
2356
 
2217
2357
 
@@ -2246,8 +2386,8 @@ class SameTypeShape(PrimitiveWithInfer):
2246
2386
  return x
2247
2387
 
2248
2388
  def __infer__(self, x, y):
2249
- validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
2250
- validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
2389
+ validator.check_subclass('x', x['dtype'], mstype.tensor_type, self.name)
2390
+ validator.check_subclass('y', y['dtype'], mstype.tensor_type, self.name)
2251
2391
  validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], validator.EQ, self.name, TypeError)
2252
2392
  validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name)
2253
2393
  return x
@@ -2374,13 +2514,15 @@ class ConvertToAdapterTensor(Primitive):
2374
2514
  >>> print(x)
2375
2515
  [1 2 3]
2376
2516
  """
2517
+
2377
2518
  @prim_attr_register
2378
2519
  def __init__(self):
2379
2520
  """Initialize"""
2380
2521
 
2381
2522
  def __call__(self, x):
2382
- """run in PyNative mode"""
2383
- return ms_adapter_registry.tensor(x, inner=True)
2523
+ """Run in PyNative mode"""
2524
+ return ms_adapter_registry.tensor(x, cast_tensor=True)
2525
+
2384
2526
 
2385
2527
  convert_to_adapter_tensor = ConvertToAdapterTensor()
2386
2528
 
@@ -2405,13 +2547,17 @@ class ConvertToMsTensor(Primitive):
2405
2547
  >>> print(x)
2406
2548
  [1 2 3]
2407
2549
  """
2550
+
2408
2551
  @prim_attr_register
2409
2552
  def __init__(self):
2410
2553
  """Initialize"""
2411
2554
 
2412
2555
  def __call__(self, x):
2413
- """run in PyNative mode"""
2414
- return Tensor(x)
2556
+ """Run in PyNative mode"""
2557
+ if isinstance(x, StubTensor):
2558
+ return StubTensor(stub=x.stub, tensor=x.tensor)
2559
+ return ops.deepcopy(x)
2560
+
2415
2561
 
2416
2562
  convert_to_ms_tensor = ConvertToMsTensor()
2417
2563
 
@@ -2458,6 +2604,7 @@ class IsParameter(PrimitiveWithInfer):
2458
2604
  """
2459
2605
  Check if input is `Parameter`
2460
2606
  """
2607
+
2461
2608
  @prim_attr_register
2462
2609
  def __init__(self):
2463
2610
  """Initialize IsParameter"""
@@ -2468,7 +2615,7 @@ class IsParameter(PrimitiveWithInfer):
2468
2615
  def __infer__(self, x):
2469
2616
  return {'shape': [],
2470
2617
  'dtype': mstype.bool_,
2471
- 'value': isinstance(x['dtype'], mstype.ref_type)}
2618
+ 'value': isinstance(x['dtype'], mstype.RefType)}
2472
2619
 
2473
2620
 
2474
2621
  class SiLU(Primitive):
@@ -2547,3 +2694,98 @@ class SetitemTensorIndexInfo(Primitive):
2547
2694
 
2548
2695
  def __call__(self, data, index, value):
2549
2696
  return Tensor_.setitem_index_info(data, index, value, self.is_ascend)
2697
+
2698
+
2699
+ class IsConstant(Primitive):
2700
+ r"""
2701
+ Check if the input is constant
2702
+ """
2703
+
2704
+ @prim_attr_register
2705
+ def __init__(self):
2706
+ """Initialize IsConstant"""
2707
+
2708
+ def __call__(self, x):
2709
+ return True
2710
+
2711
+
2712
+ class SelectView(Primitive):
2713
+ r"""
2714
+ Select tensor of view
2715
+ """
2716
+
2717
+ @prim_attr_register
2718
+ def __init__(self):
2719
+ self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output'])
2720
+
2721
+
2722
+ class CopyWithSlice(Primitive):
2723
+ r"""
2724
+ Copy data to discontinuous tensor
2725
+ """
2726
+ @prim_attr_register
2727
+ def __init__(self):
2728
+ self.add_prim_attr('side_effect_mem', True)
2729
+ self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
2730
+
2731
+
2732
+ class MoeFFN(Primitive):
2733
+ r"""
2734
+ The MoeFFN computation is similar to a Feed-Forward Network; it contains matmul + gelu + matmul.
2735
+
2736
+ Args:
2737
+ activation (string): The activation type, set to 'fastgelu' or 'gelu'.
2738
+ Only 'fastgelu' is supported for now. Default: "fastgelu".
2739
+
2740
+ Inputs:
2741
+ - **x** (Tensor) - The input tensor with data type of int8, float16.
2742
+ Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
2743
+ - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
2744
+ Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, ..., 9)`
2745
+ indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
2746
+ the 2nd expert does nothing, and so on.
2747
+ - **weight1** (Tensor) - The weight1 tensor with data type of float16.
2748
+ Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
2749
+ - **bias1** (Tensor) - The bias1 tensor with data type of float16.
2750
+ Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
2751
+ - **weight2** (Tensor) - The weight2 tensor with data type of float16.
2752
+ Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
2753
+ - **bias2** (Tensor) - The bias2 tensor with data type of float16.
2754
+ Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
2755
+ - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled now.
2756
+ - **offset** (Tensor) - The offset tensor with data type of float16. Not enabled now.
2757
+ - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enabled now.
2758
+ - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enabled now.
2759
+
2760
+ Outputs:
2761
+ Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.
2762
+
2763
+ Supported Platforms:
2764
+ ``Ascend``
2765
+
2766
+ Examples:
2767
+ >>> from mindspore.ops.operations import _inner_ops
2768
+ >>> b = 4
2769
+ >>> s = 128
2770
+ >>> h = 1024
2771
+ >>> h_f = 4 * h
2772
+ >>> e = 16
2773
+ >>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
2774
+ >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
2775
+ >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
2776
+ >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
2777
+ >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
2778
+ >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
2779
+ >>> moe_ffn = _inner_ops.MoeFFN("fastgelu")
2780
+ >>> output = moe_ffn(x, w1, bias1, w2, bias2)
2781
+ >>> print(output)
2782
+ """
2783
+
2784
+ @prim_attr_register
2785
+ def __init__(self, activation):
2786
+ """Initialize MoeFFN."""
2787
+ self.init_prim_io_names(inputs=["x", "expert_tokens", "weight1", "bias1",
2788
+ "weight2", "bias2", "scale", "offset", "deq_scale1",
2789
+ "deq_scale2"],
2790
+ outputs=["y"])
2791
+ self.activation = activation
@@ -486,9 +486,9 @@ def kernel(fn=None, reg_info=None, compile_attrs=None):
486
486
  will enjoy the automatic dtype/shape infer for free.
487
487
 
488
488
  Args:
489
- fn (Function): The Python function that will be run as a custom operator. Default: None.
490
- reg_info (tuple[str, dict]): Each item represents registration information in json format. Default: None.
491
- compile_attrs (Dict): The Python object is used to distinguish the compiled function. Default: None.
489
+ fn (Function): The Python function that will be run as a custom operator. Default: ``None`` .
490
+ reg_info (tuple[str, dict]): Each item represents registration information in json format. Default: ``None`` .
491
+ compile_attrs (Dict): The Python object is used to distinguish the compiled function. Default: ``None`` .
492
492
 
493
493
  Returns:
494
494
  Function, if `fn` is not None, returns a callable function that will execute the Hybrid DSL function;