mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore has been flagged as potentially problematic; see the package's registry page for details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -19,22 +19,24 @@ from collections.abc import Iterable
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
|
|
21
21
|
from mindspore.common import Tensor
|
|
22
|
+
from mindspore.common._stub_tensor import StubTensor
|
|
22
23
|
from mindspore.ops import composite as C
|
|
23
24
|
from mindspore.ops.operations.array_ops import Cast
|
|
24
25
|
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
|
|
26
|
+
from mindspore.ops.operations.comm_ops import ReduceOp
|
|
25
27
|
from mindspore.ops import signature as sig
|
|
26
28
|
from mindspore.ops.operations.math_ops import _infer_shape_reduce
|
|
27
|
-
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive
|
|
28
|
-
|
|
29
|
+
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive,\
|
|
30
|
+
_run_op, _check_contains_variable
|
|
29
31
|
from mindspore._c_expression import Tensor as Tensor_
|
|
30
32
|
from mindspore._c_expression import typing
|
|
31
33
|
from mindspore import _checkparam as validator
|
|
32
34
|
from mindspore.common import dtype as mstype
|
|
33
35
|
from mindspore.common.parameter import Parameter
|
|
34
|
-
from mindspore.communication.management import GlobalComm
|
|
36
|
+
from mindspore.communication.management import GlobalComm, get_rank
|
|
35
37
|
from mindspore.common.api import _pynative_executor
|
|
36
38
|
from mindspore.common._register_for_adapter import ms_adapter_registry
|
|
37
|
-
|
|
39
|
+
from mindspore import ops
|
|
38
40
|
|
|
39
41
|
# Bit operation
|
|
40
42
|
bit_and = bit_and()
|
|
@@ -73,12 +75,11 @@ class ExtractImagePatches(Primitive):
|
|
|
73
75
|
- valid: Means that the taken patch area must be completely covered in the original image.
|
|
74
76
|
|
|
75
77
|
Inputs:
|
|
76
|
-
- **input_x** (Tensor) - A 4-D tensor whose shape is
|
|
77
|
-
data type is number.
|
|
78
|
+
- **input_x** (Tensor) - A 4-D tensor whose shape is :math:`(in\_batch, in\_depth, in\_row, in\_col)`.
|
|
78
79
|
|
|
79
80
|
Outputs:
|
|
80
|
-
Tensor, a 4-D tensor whose data type is same as 'input_x',
|
|
81
|
-
|
|
81
|
+
Tensor, a 4-D tensor whose data type is same as 'input_x', and the shape
|
|
82
|
+
is :math:`(out\_batch, out\_depth, out\_row, out\_col)`,where the out_batch is the same as the in_batch
|
|
82
83
|
and
|
|
83
84
|
|
|
84
85
|
.. math::
|
|
@@ -121,7 +122,6 @@ class ExtractImagePatches(Primitive):
|
|
|
121
122
|
validator.check_value_type('padding', padding, [str], self.name)
|
|
122
123
|
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
|
123
124
|
self.add_prim_attr("padding", self.padding)
|
|
124
|
-
self.is_ge = context.get_context("enable_ge")
|
|
125
125
|
|
|
126
126
|
|
|
127
127
|
class Quant(PrimitiveWithInfer):
|
|
@@ -144,7 +144,7 @@ class Quant(PrimitiveWithInfer):
|
|
|
144
144
|
Args:
|
|
145
145
|
scale (float) : Specifies the scaling ratio.
|
|
146
146
|
offset (float): Specifies the offset.
|
|
147
|
-
sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False
|
|
147
|
+
sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
|
|
148
148
|
round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
|
|
149
149
|
Default: "Round".
|
|
150
150
|
|
|
@@ -172,7 +172,7 @@ class Quant(PrimitiveWithInfer):
|
|
|
172
172
|
return x_shape
|
|
173
173
|
|
|
174
174
|
def infer_dtype(self, x_type):
|
|
175
|
-
validator.check_subclass("input_x", x_type, mstype.
|
|
175
|
+
validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
|
|
176
176
|
validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
|
|
177
177
|
return mstype.int8
|
|
178
178
|
|
|
@@ -254,8 +254,8 @@ class Dequant(PrimitiveWithInfer):
|
|
|
254
254
|
This operation only support Ascend 310 inference environment.
|
|
255
255
|
|
|
256
256
|
Args:
|
|
257
|
-
sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False
|
|
258
|
-
relu_flag (bool): Specifies whether to perform ReLU. Default: False
|
|
257
|
+
sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
|
|
258
|
+
relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.
|
|
259
259
|
|
|
260
260
|
Inputs:
|
|
261
261
|
- **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
|
|
@@ -281,7 +281,7 @@ class Dequant(PrimitiveWithInfer):
|
|
|
281
281
|
return x_shape
|
|
282
282
|
|
|
283
283
|
def infer_dtype(self, x_type, deq_scale_type):
|
|
284
|
-
validator.check_subclass("x", x_type, mstype.
|
|
284
|
+
validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
|
|
285
285
|
validator.check_type_name("x", x_type, [mstype.int32], self.name)
|
|
286
286
|
validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
|
|
287
287
|
return mstype.float16
|
|
@@ -502,6 +502,109 @@ class Receive(PrimitiveWithInfer):
|
|
|
502
502
|
return self.get_attr_dict()['dtype']
|
|
503
503
|
|
|
504
504
|
|
|
505
|
+
class Reduce(PrimitiveWithInfer):
|
|
506
|
+
"""
|
|
507
|
+
Reduces tensor across the processes in the specified communication group.
|
|
508
|
+
|
|
509
|
+
Note:
|
|
510
|
+
Only process with destination rank receives the reduced output.
|
|
511
|
+
Other processes only get a tensor with shape [1], which has no mathematical meaning.
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
dest_rank (int): Specifies the rank of the process that receives the reduced output.
|
|
515
|
+
op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
|
|
516
|
+
On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
|
|
517
|
+
group (str, optional): The communication group to work on.
|
|
518
|
+
Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
|
|
519
|
+
|
|
520
|
+
Inputs:
|
|
521
|
+
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
522
|
+
|
|
523
|
+
Examples:
|
|
524
|
+
>>> import mindspore.ops as ops
|
|
525
|
+
>>> import mindspore.nn as nn
|
|
526
|
+
>>> from mindspore.communication import init
|
|
527
|
+
>>> from mindspore import Tensor
|
|
528
|
+
>>> import numpy as np
|
|
529
|
+
>>> # Launch 4 processes.
|
|
530
|
+
>>> init()
|
|
531
|
+
>>> class ReduceNet(nn.Cell):
|
|
532
|
+
>>> def __init__(self):
|
|
533
|
+
>>> super(Net, self).__init__()
|
|
534
|
+
>>> self.reduce = ops.Reduce(dest_rank=1)
|
|
535
|
+
>>>
|
|
536
|
+
>>> def construct(self, x):
|
|
537
|
+
>>> out = self.reduce(x)
|
|
538
|
+
>>> return out
|
|
539
|
+
>>> input = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
540
|
+
>>> net = ReduceNet()
|
|
541
|
+
>>> output = net(input)
|
|
542
|
+
>>> print(output)
|
|
543
|
+
Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
|
|
544
|
+
[4. 4. 4. 4. 4. 4. 4. 4.]],
|
|
545
|
+
Other proesses: [0.].
|
|
546
|
+
"""
|
|
547
|
+
|
|
548
|
+
@prim_attr_register
|
|
549
|
+
def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
|
550
|
+
self.dest_rank = dest_rank
|
|
551
|
+
self.op = op
|
|
552
|
+
self.group = group
|
|
553
|
+
|
|
554
|
+
def infer_shape(self, x_shape):
|
|
555
|
+
# The process with dest_rank returns the reduced output.
|
|
556
|
+
# Other processes only gets a tensor with shape [1], which has no mathematical meaning.
|
|
557
|
+
if self.dest_rank == get_rank():
|
|
558
|
+
return x_shape
|
|
559
|
+
return [1]
|
|
560
|
+
|
|
561
|
+
def infer_dtype(self, x_dtype):
|
|
562
|
+
return x_dtype
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
class Barrier(PrimitiveWithInfer):
|
|
566
|
+
"""
|
|
567
|
+
Synchronizes all processes in the specified group.
|
|
568
|
+
|
|
569
|
+
Note:
|
|
570
|
+
After calling this collective operator,
|
|
571
|
+
this process will be blocked until all other processes in the group call this operator.
|
|
572
|
+
|
|
573
|
+
Args:
|
|
574
|
+
group (str, optional): The communication group to work on.
|
|
575
|
+
Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
|
|
576
|
+
|
|
577
|
+
Examples:
|
|
578
|
+
>>> import mindspore.ops as ops
|
|
579
|
+
>>> import mindspore.nn as nn
|
|
580
|
+
>>> from mindspore.communication import init
|
|
581
|
+
>>> from mindspore import Tensor
|
|
582
|
+
>>> import numpy as np
|
|
583
|
+
>>> # Launch 4 processes.
|
|
584
|
+
>>> init()
|
|
585
|
+
>>> class BarrierNet(nn.Cell):
|
|
586
|
+
>>> def __init__(self):
|
|
587
|
+
>>> super(Net, self).__init__()
|
|
588
|
+
>>> self.barrier = ops.Barrier()
|
|
589
|
+
>>>
|
|
590
|
+
>>> def construct(self):
|
|
591
|
+
>>> self.barrier()
|
|
592
|
+
>>> net = BarrierNet()
|
|
593
|
+
>>> net()
|
|
594
|
+
"""
|
|
595
|
+
|
|
596
|
+
@prim_attr_register
|
|
597
|
+
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
|
598
|
+
self.group = group
|
|
599
|
+
self.add_prim_attr("side_effect_mem", True)
|
|
600
|
+
|
|
601
|
+
def infer_shape(self):
|
|
602
|
+
return [1]
|
|
603
|
+
|
|
604
|
+
def infer_dtype(self):
|
|
605
|
+
return mstype.float32
|
|
606
|
+
|
|
607
|
+
|
|
505
608
|
class MatrixSetDiag(PrimitiveWithInfer):
|
|
506
609
|
r"""
|
|
507
610
|
Modifies the batched diagonal part of a batched tensor.
|
|
@@ -604,9 +707,9 @@ class ConfusionMulGrad(PrimitiveWithInfer):
|
|
|
604
707
|
return outshape0, outshape1
|
|
605
708
|
|
|
606
709
|
def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
|
|
607
|
-
validator.check_subclass("input0_dtype", input0_dtype, mstype.
|
|
608
|
-
validator.check_subclass("input1_dtype", input1_dtype, mstype.
|
|
609
|
-
validator.check_subclass("input2_dtype", input2_dtype, mstype.
|
|
710
|
+
validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor_type, self.name)
|
|
711
|
+
validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor_type, self.name)
|
|
712
|
+
validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor_type, self.name)
|
|
610
713
|
return input0_dtype, input1_dtype
|
|
611
714
|
|
|
612
715
|
|
|
@@ -619,7 +722,7 @@ class ConvertToDynamic(PrimitiveWithCheck):
|
|
|
619
722
|
|
|
620
723
|
Args:
|
|
621
724
|
is_dynamic_rank (bool): If true, convert to dynamic rank.
|
|
622
|
-
If false, convert to dynamic shape. Default: False
|
|
725
|
+
If false, convert to dynamic shape. Default: ``False``.
|
|
623
726
|
|
|
624
727
|
Inputs:
|
|
625
728
|
- **input** (Tensor) - The tensor used for testing.
|
|
@@ -664,7 +767,7 @@ class ConvertToDynamic(PrimitiveWithCheck):
|
|
|
664
767
|
validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)
|
|
665
768
|
|
|
666
769
|
def check_dtype(self, input_dtype):
|
|
667
|
-
validator.check_subclass("input_dtype", input_dtype, mstype.
|
|
770
|
+
validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
|
|
668
771
|
|
|
669
772
|
|
|
670
773
|
class GpuConvertToDynamicShape(PrimitiveWithCheck):
|
|
@@ -714,7 +817,7 @@ class GpuConvertToDynamicShape(PrimitiveWithCheck):
|
|
|
714
817
|
validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)
|
|
715
818
|
|
|
716
819
|
def check_dtype(self, input_dtype):
|
|
717
|
-
validator.check_subclass("input_dtype", input_dtype, mstype.
|
|
820
|
+
validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
|
|
718
821
|
|
|
719
822
|
|
|
720
823
|
class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
|
|
@@ -766,7 +869,7 @@ class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
|
|
|
766
869
|
|
|
767
870
|
def infer_type(self, input_dtype):
|
|
768
871
|
"""Infer the dtype of input for ErrorOnDynamicShapeInput."""
|
|
769
|
-
validator.check_subclass("input_dtype", input_dtype, mstype.
|
|
872
|
+
validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
|
|
770
873
|
return input_dtype
|
|
771
874
|
|
|
772
875
|
def infer_value(self, input_tensor):
|
|
@@ -816,7 +919,7 @@ class SequenceMask(PrimitiveWithCheck):
|
|
|
816
919
|
validator.check("maxlen_shape", len(maxlen_shape), "", 0, validator.EQ, self.name)
|
|
817
920
|
|
|
818
921
|
def check_dtype(self, lengths_dtype, maxlen_dtype):
|
|
819
|
-
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.
|
|
922
|
+
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
|
|
820
923
|
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
|
|
821
924
|
|
|
822
925
|
|
|
@@ -1170,8 +1273,8 @@ class DynamicStitch(PrimitiveWithCheck):
|
|
|
1170
1273
|
return out_shape
|
|
1171
1274
|
|
|
1172
1275
|
def check_dtype(self, indices_type, data_type):
|
|
1173
|
-
validator.check_subclass("indices[0]", indices_type[0], mstype.
|
|
1174
|
-
validator.check_subclass("data[0]", data_type[0], mstype.
|
|
1276
|
+
validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
|
|
1277
|
+
validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
|
|
1175
1278
|
indices_num = len(indices_type)
|
|
1176
1279
|
for i in range(0, indices_num):
|
|
1177
1280
|
validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
|
|
@@ -1418,6 +1521,7 @@ class DecodeImage(PrimitiveWithInfer):
|
|
|
1418
1521
|
|
|
1419
1522
|
Examples:
|
|
1420
1523
|
"""
|
|
1524
|
+
|
|
1421
1525
|
@prim_attr_register
|
|
1422
1526
|
def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
|
|
1423
1527
|
_op_max_size=[8000000]):
|
|
@@ -1467,7 +1571,7 @@ class DynamicBroadcastTo(Primitive):
|
|
|
1467
1571
|
Inputs:
|
|
1468
1572
|
- **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
|
|
1469
1573
|
float16, float32, int32, int8, uint8.
|
|
1470
|
-
The shape is :math:`(N,*)` where :math:`*` means
|
|
1574
|
+
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
|
|
1471
1575
|
- **shape** (Tensor): The target shape to broadcast.
|
|
1472
1576
|
|
|
1473
1577
|
Outputs:
|
|
@@ -1495,6 +1599,16 @@ class Cummin(Primitive):
|
|
|
1495
1599
|
|
|
1496
1600
|
Refer to :func:`mindspore.ops.cummin` for more detail.
|
|
1497
1601
|
|
|
1602
|
+
Args:
|
|
1603
|
+
axis (int): The axis to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)).
|
|
1604
|
+
|
|
1605
|
+
Inputs:
|
|
1606
|
+
- **input** (Tensor) - The input tensor.
|
|
1607
|
+
|
|
1608
|
+
Outputs:
|
|
1609
|
+
A tuple of 2 Tensors(values, indices), containing the cumulative minimum of elements and the index,
|
|
1610
|
+
The shape of each output tensor is the same as input `input`.
|
|
1611
|
+
|
|
1498
1612
|
Supported Platforms:
|
|
1499
1613
|
``Ascend`` ``GPU`` ``CPU``
|
|
1500
1614
|
|
|
@@ -1509,6 +1623,7 @@ class Cummin(Primitive):
|
|
|
1509
1623
|
>>> print(output[1])
|
|
1510
1624
|
[0 1 1 1 4 4]
|
|
1511
1625
|
"""
|
|
1626
|
+
|
|
1512
1627
|
@prim_attr_register
|
|
1513
1628
|
def __init__(self, axis):
|
|
1514
1629
|
"""Initialize Cummin"""
|
|
@@ -1528,7 +1643,7 @@ class DynamicResizeNearestNeighbor(Primitive):
|
|
|
1528
1643
|
|
|
1529
1644
|
Args:
|
|
1530
1645
|
align_corners (bool): Whether the centers of the 4 corner pixels of the input
|
|
1531
|
-
and output tensors are aligned. Default: False
|
|
1646
|
+
and output tensors are aligned. Default: ``False``.
|
|
1532
1647
|
|
|
1533
1648
|
Inputs:
|
|
1534
1649
|
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
|
|
@@ -1613,7 +1728,7 @@ class PsROIPooling(PrimitiveWithInfer):
|
|
|
1613
1728
|
return output_shape, output_map_shape
|
|
1614
1729
|
|
|
1615
1730
|
def infer_dtype(self, inputs_type, rois_type):
|
|
1616
|
-
map_type = mstype.
|
|
1731
|
+
map_type = mstype.TensorType(mstype.int32)
|
|
1617
1732
|
return inputs_type, map_type
|
|
1618
1733
|
|
|
1619
1734
|
|
|
@@ -1671,8 +1786,10 @@ class PartitionedCall(PrimitiveWithInfer):
|
|
|
1671
1786
|
|
|
1672
1787
|
Examples:
|
|
1673
1788
|
"""
|
|
1789
|
+
|
|
1674
1790
|
@prim_attr_register
|
|
1675
1791
|
def __init__(self, graph, executor_type=""):
|
|
1792
|
+
super(PartitionedCall, self).__init__(self.__class__.__name__)
|
|
1676
1793
|
self.add_prim_attr("executor_type", executor_type)
|
|
1677
1794
|
self.graph = graph
|
|
1678
1795
|
|
|
@@ -1744,9 +1861,6 @@ class CellBackwardHook(PrimitiveWithInfer):
|
|
|
1744
1861
|
def __call__(self, args):
|
|
1745
1862
|
if not isinstance(args, tuple):
|
|
1746
1863
|
args = (args,)
|
|
1747
|
-
for arg in args:
|
|
1748
|
-
if isinstance(arg, Parameter) and arg.has_init:
|
|
1749
|
-
arg.init_data()
|
|
1750
1864
|
return _run_op(self, self.name, args)
|
|
1751
1865
|
|
|
1752
1866
|
def infer_shape(self, *inputs_shape):
|
|
@@ -1832,16 +1946,32 @@ class Format(PrimitiveWithInfer):
|
|
|
1832
1946
|
def __init__(self):
|
|
1833
1947
|
self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
|
|
1834
1948
|
|
|
1949
|
+
|
|
1835
1950
|
def __infer__(self, str_, *var):
|
|
1836
|
-
|
|
1951
|
+
def check_variable(str_, var):
|
|
1952
|
+
if _check_contains_variable(str_['dtype'], str_['value']):
|
|
1953
|
+
return True
|
|
1954
|
+
|
|
1955
|
+
for item in var:
|
|
1956
|
+
if _check_contains_variable(item['dtype'], item['value']):
|
|
1957
|
+
return True
|
|
1958
|
+
return False
|
|
1959
|
+
|
|
1960
|
+
|
|
1961
|
+
if check_variable(str_, var):
|
|
1962
|
+
return {'dtype': mstype.string, 'shape': [], 'value': None}
|
|
1963
|
+
|
|
1964
|
+
|
|
1965
|
+
str_value = str_['value']
|
|
1966
|
+
kwargs = dict()
|
|
1837
1967
|
var_value = list()
|
|
1838
|
-
|
|
1839
|
-
raise ValueError("str.format not support to input a variable.")
|
|
1968
|
+
|
|
1840
1969
|
for item in var:
|
|
1841
|
-
if item["
|
|
1842
|
-
|
|
1970
|
+
if isinstance(item["dtype"], typing.Keyword):
|
|
1971
|
+
kwargs.update(item["value"])
|
|
1843
1972
|
var_value.append(item["value"])
|
|
1844
|
-
|
|
1973
|
+
|
|
1974
|
+
value = str_value.format(*var_value, **kwargs)
|
|
1845
1975
|
return {'dtype': mstype.string, 'shape': [], 'value': value}
|
|
1846
1976
|
|
|
1847
1977
|
|
|
@@ -1982,7 +2112,7 @@ class ClipByNorm(PrimitiveWithInfer):
|
|
|
1982
2112
|
|
|
1983
2113
|
Args:
|
|
1984
2114
|
axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specific dimension.
|
|
1985
|
-
Default: None
|
|
2115
|
+
Default: ``None``, all dimensions to calculate.
|
|
1986
2116
|
|
|
1987
2117
|
Inputs:
|
|
1988
2118
|
- **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
|
|
@@ -2060,8 +2190,8 @@ class TopTypeof(Primitive):
|
|
|
2060
2190
|
'slice': mstype.Slice(),
|
|
2061
2191
|
'list': mstype.List(),
|
|
2062
2192
|
'tuple': mstype.Tuple(),
|
|
2063
|
-
'Tensor': mstype.
|
|
2064
|
-
'NoneType': mstype.
|
|
2193
|
+
'Tensor': mstype.tensor_type,
|
|
2194
|
+
'NoneType': mstype.NoneType(),
|
|
2065
2195
|
'int': mstype.Int(),
|
|
2066
2196
|
'bool': mstype.Bool(),
|
|
2067
2197
|
'ellipsis': mstype.Ellipsis_(),
|
|
@@ -2098,7 +2228,7 @@ class MixedPrecisionCast(Primitive):
|
|
|
2098
2228
|
Examples:
|
|
2099
2229
|
>>> import numpy as np
|
|
2100
2230
|
>>> from mindspore import Tensor
|
|
2101
|
-
>>> from mindspore
|
|
2231
|
+
>>> from mindspore import dtype as mstype
|
|
2102
2232
|
>>> from mindspore.ops.operations import _inner_ops as inner
|
|
2103
2233
|
>>> x = Tensor(np.ones([2, 3], dtype=np.float32))
|
|
2104
2234
|
>>> out = inner.MixedPrecisionCast(mstype.float16, x)
|
|
@@ -2175,13 +2305,22 @@ class CheckBprop(PrimitiveWithInfer):
|
|
|
2175
2305
|
raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
|
|
2176
2306
|
f"the number of input arguments except 'out' and 'dout', "
|
|
2177
2307
|
f"which is:{len(yshapes)} but got {len(xshapes)}.")
|
|
2178
|
-
|
|
2179
|
-
|
|
2180
|
-
|
|
2181
|
-
|
|
2308
|
+
|
|
2309
|
+
def shape_equal(shape1, shape2):
|
|
2310
|
+
if len(shape1) != len(shape2):
|
|
2311
|
+
return False
|
|
2312
|
+
for shape_axis1, shape_axis2 in zip(shape1, shape2):
|
|
2313
|
+
if shape_axis1 == -1 or shape_axis2 == -1:
|
|
2314
|
+
continue
|
|
2315
|
+
if shape_axis1 != shape_axis2:
|
|
2316
|
+
return False
|
|
2317
|
+
return True
|
|
2318
|
+
|
|
2319
|
+
for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
|
|
2182
2320
|
if not xshape or not yshape:
|
|
2183
2321
|
continue
|
|
2184
|
-
|
|
2322
|
+
|
|
2323
|
+
if not shape_equal(xshape, yshape):
|
|
2185
2324
|
raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
|
|
2186
2325
|
f"should have the same shape as the {i}th argument, "
|
|
2187
2326
|
f"which is:{yshape}, but got: {xshape}.")
|
|
@@ -2200,18 +2339,19 @@ class CheckBprop(PrimitiveWithInfer):
|
|
|
2200
2339
|
for i in range(checking_range):
|
|
2201
2340
|
xdtype = xdtypes[i]
|
|
2202
2341
|
ydtype = ydtypes[i]
|
|
2203
|
-
if isinstance(xdtype, mstype.
|
|
2342
|
+
if isinstance(xdtype, mstype.AnythingType) or isinstance(ydtype, mstype.AnythingType):
|
|
2204
2343
|
continue
|
|
2205
|
-
if isinstance(ydtype, mstype.
|
|
2206
|
-
if not isinstance(xdtype, mstype.
|
|
2344
|
+
if isinstance(ydtype, mstype.FunctionType):
|
|
2345
|
+
if not isinstance(xdtype, mstype.EnvType):
|
|
2207
2346
|
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
|
|
2208
|
-
f"should be {mstype.
|
|
2347
|
+
f"should be {mstype.EnvType}, but got {xdtype}.")
|
|
2209
2348
|
if xdtype != ydtype:
|
|
2210
2349
|
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
|
|
2211
2350
|
f"should have the same dtype as the {i}th argument, "
|
|
2212
2351
|
f"which is:{ydtype}, but got: {xdtype}.")
|
|
2213
2352
|
return xdtypes
|
|
2214
2353
|
|
|
2354
|
+
|
|
2215
2355
|
check_bprop = CheckBprop()
|
|
2216
2356
|
|
|
2217
2357
|
|
|
@@ -2246,8 +2386,8 @@ class SameTypeShape(PrimitiveWithInfer):
|
|
|
2246
2386
|
return x
|
|
2247
2387
|
|
|
2248
2388
|
def __infer__(self, x, y):
|
|
2249
|
-
validator.check_subclass('x', x['dtype'], mstype.
|
|
2250
|
-
validator.check_subclass('y', y['dtype'], mstype.
|
|
2389
|
+
validator.check_subclass('x', x['dtype'], mstype.tensor_type, self.name)
|
|
2390
|
+
validator.check_subclass('y', y['dtype'], mstype.tensor_type, self.name)
|
|
2251
2391
|
validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], validator.EQ, self.name, TypeError)
|
|
2252
2392
|
validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name)
|
|
2253
2393
|
return x
|
|
@@ -2374,13 +2514,15 @@ class ConvertToAdapterTensor(Primitive):
|
|
|
2374
2514
|
>>> print(x)
|
|
2375
2515
|
[1 2 3]
|
|
2376
2516
|
"""
|
|
2517
|
+
|
|
2377
2518
|
@prim_attr_register
|
|
2378
2519
|
def __init__(self):
|
|
2379
2520
|
"""Initialize"""
|
|
2380
2521
|
|
|
2381
2522
|
def __call__(self, x):
|
|
2382
|
-
"""
|
|
2383
|
-
return ms_adapter_registry.tensor(x,
|
|
2523
|
+
"""Run in PyNative mode"""
|
|
2524
|
+
return ms_adapter_registry.tensor(x, cast_tensor=True)
|
|
2525
|
+
|
|
2384
2526
|
|
|
2385
2527
|
convert_to_adapter_tensor = ConvertToAdapterTensor()
|
|
2386
2528
|
|
|
@@ -2405,13 +2547,17 @@ class ConvertToMsTensor(Primitive):
|
|
|
2405
2547
|
>>> print(x)
|
|
2406
2548
|
[1 2 3]
|
|
2407
2549
|
"""
|
|
2550
|
+
|
|
2408
2551
|
@prim_attr_register
|
|
2409
2552
|
def __init__(self):
|
|
2410
2553
|
"""Initialize"""
|
|
2411
2554
|
|
|
2412
2555
|
def __call__(self, x):
|
|
2413
|
-
"""
|
|
2414
|
-
|
|
2556
|
+
"""Run in PyNative mode"""
|
|
2557
|
+
if isinstance(x, StubTensor):
|
|
2558
|
+
return StubTensor(stub=x.stub, tensor=x.tensor)
|
|
2559
|
+
return ops.deepcopy(x)
|
|
2560
|
+
|
|
2415
2561
|
|
|
2416
2562
|
convert_to_ms_tensor = ConvertToMsTensor()
|
|
2417
2563
|
|
|
@@ -2458,6 +2604,7 @@ class IsParameter(PrimitiveWithInfer):
|
|
|
2458
2604
|
"""
|
|
2459
2605
|
Check if input is `Parameter`
|
|
2460
2606
|
"""
|
|
2607
|
+
|
|
2461
2608
|
@prim_attr_register
|
|
2462
2609
|
def __init__(self):
|
|
2463
2610
|
"""Initialize IsParameter"""
|
|
@@ -2468,7 +2615,7 @@ class IsParameter(PrimitiveWithInfer):
|
|
|
2468
2615
|
def __infer__(self, x):
|
|
2469
2616
|
return {'shape': [],
|
|
2470
2617
|
'dtype': mstype.bool_,
|
|
2471
|
-
'value': isinstance(x['dtype'], mstype.
|
|
2618
|
+
'value': isinstance(x['dtype'], mstype.RefType)}
|
|
2472
2619
|
|
|
2473
2620
|
|
|
2474
2621
|
class SiLU(Primitive):
|
|
@@ -2547,3 +2694,98 @@ class SetitemTensorIndexInfo(Primitive):
|
|
|
2547
2694
|
|
|
2548
2695
|
def __call__(self, data, index, value):
|
|
2549
2696
|
return Tensor_.setitem_index_info(data, index, value, self.is_ascend)
|
|
2697
|
+
|
|
2698
|
+
|
|
2699
|
+
class IsConstant(Primitive):
|
|
2700
|
+
r"""
|
|
2701
|
+
Check if the input is constant
|
|
2702
|
+
"""
|
|
2703
|
+
|
|
2704
|
+
@prim_attr_register
|
|
2705
|
+
def __init__(self):
|
|
2706
|
+
"""Initialize IsConstant"""
|
|
2707
|
+
|
|
2708
|
+
def __call__(self, x):
|
|
2709
|
+
return True
|
|
2710
|
+
|
|
2711
|
+
|
|
2712
|
+
class SelectView(Primitive):
|
|
2713
|
+
r"""
|
|
2714
|
+
Select tensor of view
|
|
2715
|
+
"""
|
|
2716
|
+
|
|
2717
|
+
@prim_attr_register
|
|
2718
|
+
def __init__(self):
|
|
2719
|
+
self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output'])
|
|
2720
|
+
|
|
2721
|
+
|
|
2722
|
+
class CopyWithSlice(Primitive):
|
|
2723
|
+
r"""
|
|
2724
|
+
Copy data to discontinuous tensor
|
|
2725
|
+
"""
|
|
2726
|
+
@prim_attr_register
|
|
2727
|
+
def __init__(self):
|
|
2728
|
+
self.add_prim_attr('side_effect_mem', True)
|
|
2729
|
+
self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
|
|
2730
|
+
|
|
2731
|
+
|
|
2732
|
+
class MoeFFN(Primitive):
|
|
2733
|
+
r"""
|
|
2734
|
+
The MoeFFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
|
|
2735
|
+
|
|
2736
|
+
Args:
|
|
2737
|
+
activation (string): The activation type, set to 'fastgelu' or 'gelu'.
|
|
2738
|
+
Only support 'fastgelu' for now. Default: "fastgelu".
|
|
2739
|
+
|
|
2740
|
+
Inputs:
|
|
2741
|
+
- **x** (Tensor) - The input tensor with data type of int8, float16.
|
|
2742
|
+
Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
|
|
2743
|
+
- **expert_tokens** (Tensor]) - The expert tokens tensor with data type of int64.
|
|
2744
|
+
Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
|
|
2745
|
+
indicate that the 0th expert deals with 2 tokens, the 1th expert deals with 1 tokens,
|
|
2746
|
+
the 2th expert do noting and so on.
|
|
2747
|
+
- **weight1** (Tensor) - The weight1 tensor with data type of float16.
|
|
2748
|
+
Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
|
|
2749
|
+
- **bias1** (Tensor) - The bias1 tensor with data type of float16.
|
|
2750
|
+
Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
|
|
2751
|
+
- **weight2** (Tensor) - The weight2 tensor with data type of float16.
|
|
2752
|
+
Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
|
|
2753
|
+
- **bias2** (Tensor) - The bias2 tensor with data type of float16.
|
|
2754
|
+
Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
|
|
2755
|
+
- **scale** (Tensor) - The scale tensor with data type of float16. Not enable now.
|
|
2756
|
+
- **offset** (Tensor) - The offset tensor with data type of float16. Not enable now.
|
|
2757
|
+
- **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enable now.
|
|
2758
|
+
- **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enable now.
|
|
2759
|
+
|
|
2760
|
+
Outputs:
|
|
2761
|
+
Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.
|
|
2762
|
+
|
|
2763
|
+
Supported Platforms:
|
|
2764
|
+
``Ascend``
|
|
2765
|
+
|
|
2766
|
+
Examples:
|
|
2767
|
+
>>> from mindspore.ops.operations import _inner_ops
|
|
2768
|
+
>>> b = 4
|
|
2769
|
+
>>> s = 128
|
|
2770
|
+
>>> h = 1024
|
|
2771
|
+
>>> h_f = 4 * h
|
|
2772
|
+
>>> e = 16
|
|
2773
|
+
>>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
|
|
2774
|
+
>>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
|
|
2775
|
+
>>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
|
|
2776
|
+
>>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
|
|
2777
|
+
>>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
|
|
2778
|
+
>>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
|
|
2779
|
+
>>> moe_ffn = _inner_ops.MoeFFN("fastgelu")
|
|
2780
|
+
>>> output = moe_ffn(x, w1, bias1, w2, bias2)
|
|
2781
|
+
>>> print(output)
|
|
2782
|
+
"""
|
|
2783
|
+
|
|
2784
|
+
@prim_attr_register
|
|
2785
|
+
def __init__(self, activation):
|
|
2786
|
+
"""Initialize MoeFFN."""
|
|
2787
|
+
self.init_prim_io_names(inputs=["x", "expert_tokens", "weight1", "bias1",
|
|
2788
|
+
"weight2", "bias2", "scale", "offset", "deq_scale1"
|
|
2789
|
+
"deq_scale2"],
|
|
2790
|
+
outputs=["y"])
|
|
2791
|
+
self.activation = activation
|
|
@@ -486,9 +486,9 @@ def kernel(fn=None, reg_info=None, compile_attrs=None):
|
|
|
486
486
|
will enjoy the automatic dtype/shape infer for free.
|
|
487
487
|
|
|
488
488
|
Args:
|
|
489
|
-
fn (Function): The Python function that will be run as a custom operator. Default: None.
|
|
490
|
-
reg_info (tuple[str, dict]): Each item represents registration information in json format. Default: None.
|
|
491
|
-
compile_attrs (Dict): The Python object is used to distinguish the compiled function. Default: None.
|
|
489
|
+
fn (Function): The Python function that will be run as a custom operator. Default: ``None`` .
|
|
490
|
+
reg_info (tuple[str, dict]): Each item represents registration information in json format. Default: ``None`` .
|
|
491
|
+
compile_attrs (Dict): The Python object is used to distinguish the compiled function. Default: ``None`` .
|
|
492
492
|
|
|
493
493
|
Returns:
|
|
494
494
|
Function, if `fn` is not None, returns a callable function that will execute the Hybrid DSL function;
|