mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the advisory on the package registry's page for this release for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/rnn_cells.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""RNN Cells module, include RNNCell, GRUCell, LSTMCell."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
|
+
from functools import wraps
|
|
17
18
|
|
|
18
19
|
import math
|
|
19
20
|
import numpy as np
|
|
@@ -39,8 +40,8 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
|
|
39
40
|
@constexpr(check=False)
|
|
40
41
|
def _check_is_tensor(param_name, input_data, cls_name):
|
|
41
42
|
"""Internal function, used to check whether the input data is Tensor."""
|
|
42
|
-
if input_data is not None and not isinstance(P.typeof(input_data), mstype.
|
|
43
|
-
raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.
|
|
43
|
+
if input_data is not None and not isinstance(P.typeof(input_data), mstype.TensorType):
|
|
44
|
+
raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.TensorType}', "
|
|
44
45
|
f"but got '{P.typeof(input_data)}'")
|
|
45
46
|
|
|
46
47
|
|
|
@@ -68,6 +69,8 @@ def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
|
|
|
68
69
|
|
|
69
70
|
|
|
70
71
|
def _check_lstmcell_init(func):
|
|
72
|
+
"""Internal function, used to check init args."""
|
|
73
|
+
@wraps(func)
|
|
71
74
|
def wrapper(*args, **kwargs):
|
|
72
75
|
logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
|
|
73
76
|
f"if you still need use single LSTM layer, please use `nn.LSTM` instead.")
|
|
@@ -142,7 +145,8 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|
|
142
145
|
|
|
143
146
|
class RNNCellBase(Cell):
|
|
144
147
|
'''Basic class for RNN Cells'''
|
|
145
|
-
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int
|
|
148
|
+
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
|
|
149
|
+
dtype=mstype.float32):
|
|
146
150
|
super().__init__()
|
|
147
151
|
validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
|
|
148
152
|
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
|
@@ -150,20 +154,20 @@ class RNNCellBase(Cell):
|
|
|
150
154
|
self.input_size = input_size
|
|
151
155
|
self.hidden_size = hidden_size
|
|
152
156
|
self.has_bias = has_bias
|
|
153
|
-
self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size)
|
|
154
|
-
self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size)
|
|
157
|
+
self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size), dtype=dtype))
|
|
158
|
+
self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size), dtype=dtype))
|
|
155
159
|
if has_bias:
|
|
156
|
-
self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size)
|
|
157
|
-
self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size)
|
|
160
|
+
self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size), dtype=dtype))
|
|
161
|
+
self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size), dtype=dtype))
|
|
158
162
|
else:
|
|
159
163
|
self.bias_ih = None
|
|
160
164
|
self.bias_hh = None
|
|
161
|
-
self.reset_parameters()
|
|
165
|
+
self.reset_parameters(dtype=dtype)
|
|
162
166
|
|
|
163
|
-
def reset_parameters(self):
|
|
167
|
+
def reset_parameters(self, dtype=mstype.float32):
|
|
164
168
|
stdv = 1 / math.sqrt(self.hidden_size)
|
|
165
169
|
for weight in self.get_parameters():
|
|
166
|
-
weight.set_data(initializer(Uniform(stdv), weight.shape))
|
|
170
|
+
weight.set_data(initializer(Uniform(stdv), weight.shape, dtype))
|
|
167
171
|
|
|
168
172
|
|
|
169
173
|
class RNNCell(RNNCellBase):
|
|
@@ -181,13 +185,14 @@ class RNNCell(RNNCellBase):
|
|
|
181
185
|
Args:
|
|
182
186
|
input_size (int): Number of features of input.
|
|
183
187
|
hidden_size (int): Number of features of hidden layer.
|
|
184
|
-
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
|
185
|
-
nonlinearity (str): The non-linearity to use. Can be either
|
|
188
|
+
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: ``True`` .
|
|
189
|
+
nonlinearity (str): The non-linearity to use. Can be either ``"tanh"`` or ``"relu"`` .
|
|
190
|
+
Default: ``"tanh"`` .
|
|
191
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
186
192
|
|
|
187
193
|
Inputs:
|
|
188
194
|
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
|
|
189
195
|
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)` .
|
|
190
|
-
Data type of `hx` must be the same as `x`.
|
|
191
196
|
|
|
192
197
|
Outputs:
|
|
193
198
|
- **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` .
|
|
@@ -201,9 +206,11 @@ class RNNCell(RNNCellBase):
|
|
|
201
206
|
``Ascend`` ``GPU`` ``CPU``
|
|
202
207
|
|
|
203
208
|
Examples:
|
|
204
|
-
>>>
|
|
205
|
-
>>>
|
|
206
|
-
>>>
|
|
209
|
+
>>> import mindspore as ms
|
|
210
|
+
>>> import numpy as np
|
|
211
|
+
>>> net = ms.nn.RNNCell(10, 16)
|
|
212
|
+
>>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
|
|
213
|
+
>>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32))
|
|
207
214
|
>>> output = []
|
|
208
215
|
>>> for i in range(5):
|
|
209
216
|
... hx = net(x[i], hx)
|
|
@@ -213,8 +220,9 @@ class RNNCell(RNNCellBase):
|
|
|
213
220
|
"""
|
|
214
221
|
_non_linearity = ['tanh', 'relu']
|
|
215
222
|
|
|
216
|
-
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"
|
|
217
|
-
|
|
223
|
+
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh",
|
|
224
|
+
dtype=mstype.float32):
|
|
225
|
+
super().__init__(input_size, hidden_size, has_bias, num_chunks=1, dtype=dtype)
|
|
218
226
|
validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
|
|
219
227
|
validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
|
|
220
228
|
self.nonlinearity = nonlinearity
|
|
@@ -263,15 +271,16 @@ class LSTMCell(RNNCellBase):
|
|
|
263
271
|
Args:
|
|
264
272
|
input_size (int): Number of features of input.
|
|
265
273
|
hidden_size (int): Number of features of hidden layer.
|
|
266
|
-
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
|
274
|
+
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: ``True`` .
|
|
275
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
267
276
|
|
|
268
277
|
Inputs:
|
|
269
|
-
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)
|
|
278
|
+
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
|
|
270
279
|
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
|
|
271
|
-
and shape :math:`(batch\_size, hidden\_size)
|
|
280
|
+
and shape :math:`(batch\_size, hidden\_size)` .
|
|
272
281
|
|
|
273
282
|
Outputs:
|
|
274
|
-
- **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)
|
|
283
|
+
- **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)` .
|
|
275
284
|
|
|
276
285
|
Raises:
|
|
277
286
|
TypeError: If `input_size`, `hidden_size` is not an int.
|
|
@@ -281,10 +290,12 @@ class LSTMCell(RNNCellBase):
|
|
|
281
290
|
``Ascend`` ``GPU`` ``CPU``
|
|
282
291
|
|
|
283
292
|
Examples:
|
|
284
|
-
>>>
|
|
285
|
-
>>>
|
|
286
|
-
>>>
|
|
287
|
-
>>>
|
|
293
|
+
>>> import mindspore as ms
|
|
294
|
+
>>> import numpy as np
|
|
295
|
+
>>> net = ms.nn.LSTMCell(10, 16)
|
|
296
|
+
>>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
|
|
297
|
+
>>> h = ms.Tensor(np.ones([3, 16]).astype(np.float32))
|
|
298
|
+
>>> c = ms.Tensor(np.ones([3, 16]).astype(np.float32))
|
|
288
299
|
>>> output = []
|
|
289
300
|
>>> for i in range(5):
|
|
290
301
|
... hx = net(x[i], (h, c))
|
|
@@ -293,8 +304,9 @@ class LSTMCell(RNNCellBase):
|
|
|
293
304
|
(3, 16)
|
|
294
305
|
"""
|
|
295
306
|
@_check_lstmcell_init
|
|
296
|
-
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True
|
|
297
|
-
|
|
307
|
+
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True,
|
|
308
|
+
dtype=mstype.float32):
|
|
309
|
+
super().__init__(input_size, hidden_size, has_bias, num_chunks=4, dtype=dtype)
|
|
298
310
|
self.support_non_tensor_inputs = True
|
|
299
311
|
|
|
300
312
|
def construct(self, x, hx):
|
|
@@ -343,15 +355,15 @@ class GRUCell(RNNCellBase):
|
|
|
343
355
|
Args:
|
|
344
356
|
input_size (int): Number of features of input.
|
|
345
357
|
hidden_size (int): Number of features of hidden layer.
|
|
346
|
-
has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: True.
|
|
358
|
+
has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: ``True`` .
|
|
359
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
347
360
|
|
|
348
361
|
Inputs:
|
|
349
|
-
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)
|
|
350
|
-
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)
|
|
351
|
-
Data type of `hx` must be the same as `x`.
|
|
362
|
+
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
|
|
363
|
+
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)` .
|
|
352
364
|
|
|
353
365
|
Outputs:
|
|
354
|
-
- **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)
|
|
366
|
+
- **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` .
|
|
355
367
|
|
|
356
368
|
Raises:
|
|
357
369
|
TypeError: If `input_size`, `hidden_size` is not an int.
|
|
@@ -361,9 +373,11 @@ class GRUCell(RNNCellBase):
|
|
|
361
373
|
``Ascend`` ``GPU`` ``CPU``
|
|
362
374
|
|
|
363
375
|
Examples:
|
|
364
|
-
>>>
|
|
365
|
-
>>>
|
|
366
|
-
>>>
|
|
376
|
+
>>> import mindspore as ms
|
|
377
|
+
>>> import numpy as np
|
|
378
|
+
>>> net = ms.nn.GRUCell(10, 16)
|
|
379
|
+
>>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
|
|
380
|
+
>>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32))
|
|
367
381
|
>>> output = []
|
|
368
382
|
>>> for i in range(5):
|
|
369
383
|
... hx = net(x[i], hx)
|
|
@@ -371,8 +385,9 @@ class GRUCell(RNNCellBase):
|
|
|
371
385
|
>>> print(output[0].shape)
|
|
372
386
|
(3, 16)
|
|
373
387
|
"""
|
|
374
|
-
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True
|
|
375
|
-
|
|
388
|
+
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True,
|
|
389
|
+
dtype=mstype.float32):
|
|
390
|
+
super().__init__(input_size, hidden_size, has_bias, num_chunks=3, dtype=dtype)
|
|
376
391
|
|
|
377
392
|
def construct(self, x, hx):
|
|
378
393
|
_check_is_tensor('x', x, self.cls_name)
|
mindspore/nn/layer/rnns.py
CHANGED
|
@@ -35,10 +35,9 @@ from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_ce
|
|
|
35
35
|
__all__ = ['LSTM', 'GRU', 'RNN']
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
@_primexpr
|
|
39
38
|
def _init_state(shape, dtype, is_lstm):
|
|
40
|
-
hx = P.
|
|
41
|
-
cx = P.
|
|
39
|
+
hx = P.zeros(shape, dtype)
|
|
40
|
+
cx = P.zeros(shape, dtype)
|
|
42
41
|
if is_lstm:
|
|
43
42
|
return (hx, cx)
|
|
44
43
|
return hx
|
|
@@ -58,8 +57,8 @@ def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_n
|
|
|
58
57
|
@constexpr(check=False)
|
|
59
58
|
def _check_is_tensor(param_name, input_data, cls_name):
|
|
60
59
|
"""Internal function, used to check whether the input data is Tensor."""
|
|
61
|
-
if input_data is not None and not isinstance(P.typeof(input_data), mstype.
|
|
62
|
-
raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.
|
|
60
|
+
if input_data is not None and not isinstance(P.typeof(input_data), mstype.TensorType):
|
|
61
|
+
raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.TensorType}', "
|
|
63
62
|
f"but got '{P.typeof(input_data)}'")
|
|
64
63
|
|
|
65
64
|
|
|
@@ -260,8 +259,8 @@ class _DynamicGRUAscend(Cell):
|
|
|
260
259
|
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
|
261
260
|
'''Dynamic GRU module on Ascend'''
|
|
262
261
|
if b_ih is None:
|
|
263
|
-
b_ih = P.
|
|
264
|
-
b_hh = P.
|
|
262
|
+
b_ih = P.zeros(w_ih.shape[0], w_ih.dtype)
|
|
263
|
+
b_hh = P.zeros(w_ih.shape[0], w_ih.dtype)
|
|
265
264
|
outputs, _, _, _, _, _ = self.gru(self.cast(x, self.dtype), \
|
|
266
265
|
self.cast(self.transpose(w_ih, (1, 0)), self.dtype), \
|
|
267
266
|
self.cast(self.transpose(w_hh, (1, 0)), self.dtype), \
|
|
@@ -344,7 +343,7 @@ class _DynamicLSTMAscend(Cell):
|
|
|
344
343
|
w_hh = self.concat_dim0((w_hh_i, w_hh_g, w_hh_f, w_hh_o))
|
|
345
344
|
weight = self.concat_dim1((w_ih, w_hh))
|
|
346
345
|
if b_ih is None:
|
|
347
|
-
bias = P.
|
|
346
|
+
bias = P.zeros(w_ih.shape[0], w_ih.dtype)
|
|
348
347
|
else:
|
|
349
348
|
b_ih_i, b_ih_f, b_ih_g, b_ih_o = self.split(b_ih)
|
|
350
349
|
b_hh_i, b_hh_f, b_hh_g, b_hh_o = self.split(b_hh)
|
|
@@ -373,7 +372,7 @@ class _RNNBase(Cell):
|
|
|
373
372
|
'''Basic class for RNN operators'''
|
|
374
373
|
|
|
375
374
|
def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
|
|
376
|
-
batch_first=False, dropout=0., bidirectional=False):
|
|
375
|
+
batch_first=False, dropout=0., bidirectional=False, dtype=mstype.float32):
|
|
377
376
|
super().__init__()
|
|
378
377
|
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
|
379
378
|
validator.check_positive_int(input_size, "input_size", self.cls_name)
|
|
@@ -436,17 +435,17 @@ class _RNNBase(Cell):
|
|
|
436
435
|
suffix = '_reverse' if direction == 1 else ''
|
|
437
436
|
|
|
438
437
|
self.w_ih_list.append(Parameter(
|
|
439
|
-
Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)
|
|
440
|
-
|
|
438
|
+
Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32),
|
|
439
|
+
dtype=dtype), name='weight_ih_l{}{}'.format(layer, suffix)))
|
|
441
440
|
self.w_hh_list.append(Parameter(
|
|
442
|
-
Tensor(np.random.uniform(-stdv, stdv, (gate_size, hidden_size)).astype(np.float32)
|
|
443
|
-
|
|
441
|
+
Tensor(np.random.uniform(-stdv, stdv, (gate_size, hidden_size)).astype(np.float32),
|
|
442
|
+
dtype=dtype), name='weight_hh_l{}{}'.format(layer, suffix)))
|
|
444
443
|
if has_bias:
|
|
445
444
|
self.b_ih_list.append(Parameter(
|
|
446
|
-
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
|
|
445
|
+
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32), dtype=dtype),
|
|
447
446
|
name='bias_ih_l{}{}'.format(layer, suffix)))
|
|
448
447
|
self.b_hh_list.append(Parameter(
|
|
449
|
-
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
|
|
448
|
+
Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32), dtype=dtype),
|
|
450
449
|
name='bias_hh_l{}{}'.format(layer, suffix)))
|
|
451
450
|
self.w_ih_list = ParameterTuple(self.w_ih_list)
|
|
452
451
|
self.w_hh_list = ParameterTuple(self.w_hh_list)
|
|
@@ -579,9 +578,7 @@ class _RNNBase(Cell):
|
|
|
579
578
|
|
|
580
579
|
class RNN(_RNNBase):
|
|
581
580
|
r"""
|
|
582
|
-
Stacked Elman RNN layers.
|
|
583
|
-
|
|
584
|
-
Apply RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input.
|
|
581
|
+
Stacked Elman RNN layers, applying RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input.
|
|
585
582
|
|
|
586
583
|
For each element in the input sequence, each layer computes the following function:
|
|
587
584
|
|
|
@@ -596,23 +593,23 @@ class RNN(_RNNBase):
|
|
|
596
593
|
Args:
|
|
597
594
|
input_size (int): Number of features of input.
|
|
598
595
|
hidden_size (int): Number of features of hidden layer.
|
|
599
|
-
num_layers (int): Number of layers of stacked RNN. Default: 1.
|
|
596
|
+
num_layers (int): Number of layers of stacked RNN. Default: ``1`` .
|
|
600
597
|
nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
|
|
601
|
-
has_bias (bool): Whether the cell has bias `
|
|
602
|
-
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
|
|
598
|
+
has_bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{hh}`. Default: ``True`` .
|
|
599
|
+
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` .
|
|
603
600
|
dropout (float): If not 0.0, append `Dropout` layer on the outputs of each
|
|
604
|
-
RNN layer except the last layer. Default 0.0. The range of dropout is [0.0, 1.0).
|
|
601
|
+
RNN layer except the last layer. Default ``0.0`` . The range of dropout is [0.0, 1.0).
|
|
605
602
|
bidirectional (bool): Specifies whether it is a bidirectional RNN,
|
|
606
|
-
num_directions=2 if bidirectional=True otherwise 1. Default: False.
|
|
603
|
+
num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` .
|
|
604
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
607
605
|
|
|
608
606
|
Inputs:
|
|
609
607
|
- **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
|
|
610
608
|
shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)` .
|
|
611
609
|
- **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
|
|
612
610
|
shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` .
|
|
613
|
-
The data type of `hx` must be the same as `x`.
|
|
614
611
|
- **seq_length** (Tensor) - The length of each sequence in an input batch.
|
|
615
|
-
Tensor of shape :math:`(batch\_size)` . Default: None.
|
|
612
|
+
Tensor of shape :math:`(batch\_size)` . Default: ``None`` .
|
|
616
613
|
This input indicates the real sequence length before padding to avoid padded elements
|
|
617
614
|
have been used to compute hidden state and affect the final output. It is recommended to
|
|
618
615
|
use this input when `x` has padding elements.
|
|
@@ -635,9 +632,11 @@ class RNN(_RNNBase):
|
|
|
635
632
|
``Ascend`` ``GPU`` ``CPU``
|
|
636
633
|
|
|
637
634
|
Examples:
|
|
638
|
-
>>>
|
|
639
|
-
>>>
|
|
640
|
-
>>>
|
|
635
|
+
>>> import mindspore as ms
|
|
636
|
+
>>> import numpy as np
|
|
637
|
+
>>> net = ms.nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
|
|
638
|
+
>>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32))
|
|
639
|
+
>>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
|
|
641
640
|
>>> output, hn = net(x, h0)
|
|
642
641
|
>>> print(output.shape)
|
|
643
642
|
(3, 5, 16)
|
|
@@ -696,22 +695,22 @@ class GRU(_RNNBase):
|
|
|
696
695
|
Args:
|
|
697
696
|
input_size (int): Number of features of input.
|
|
698
697
|
hidden_size (int): Number of features of hidden layer.
|
|
699
|
-
num_layers (int): Number of layers of stacked GRU. Default: 1.
|
|
700
|
-
has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: True.
|
|
701
|
-
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
|
|
698
|
+
num_layers (int): Number of layers of stacked GRU. Default: ``1`` .
|
|
699
|
+
has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: ``True`` .
|
|
700
|
+
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` .
|
|
702
701
|
dropout (float): If not 0.0, append `Dropout` layer on the outputs of each
|
|
703
|
-
GRU layer except the last layer. Default 0.0. The range of dropout is [0.0, 1.0).
|
|
702
|
+
GRU layer except the last layer. Default ``0.0`` . The range of dropout is [0.0, 1.0).
|
|
704
703
|
bidirectional (bool): Specifies whether it is a bidirectional GRU,
|
|
705
|
-
num_directions=2 if bidirectional=True otherwise 1. Default: False.
|
|
704
|
+
num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` .
|
|
705
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
706
706
|
|
|
707
707
|
Inputs:
|
|
708
708
|
- **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
|
|
709
|
-
shape (
|
|
709
|
+
shape (seq\_len, batch\_size, `input\_size`) or :math:`(batch\_size, seq\_len, input\_size)`.
|
|
710
710
|
- **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
|
|
711
|
-
shape (
|
|
712
|
-
`x`.
|
|
711
|
+
shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`.
|
|
713
712
|
- **seq_length** (Tensor) - The length of each sequence in an input batch.
|
|
714
|
-
Tensor of shape :math:`(\text{batch_size})`. Default: None.
|
|
713
|
+
Tensor of shape :math:`(\text{batch_size})`. Default: ``None`` .
|
|
715
714
|
This input indicates the real sequence length before padding to avoid padded elements
|
|
716
715
|
have been used to compute hidden state and affect the final output. It is recommended to
|
|
717
716
|
use this input when **x** has padding elements.
|
|
@@ -719,9 +718,9 @@ class GRU(_RNNBase):
|
|
|
719
718
|
Outputs:
|
|
720
719
|
Tuple, a tuple contains (`output`, `h_n`).
|
|
721
720
|
|
|
722
|
-
- **output** (Tensor) - Tensor of shape (
|
|
723
|
-
(
|
|
724
|
-
- **hx_n** (Tensor) - Tensor of shape (
|
|
721
|
+
- **output** (Tensor) - Tensor of shape :math:`(seq\_len, batch\_size, num\_directions * hidden\_size)` or
|
|
722
|
+
:math:`(batch\_size, seq\_len, num\_directions * hidden\_size)`.
|
|
723
|
+
- **hx_n** (Tensor) - Tensor of shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`.
|
|
725
724
|
|
|
726
725
|
Raises:
|
|
727
726
|
TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
|
|
@@ -733,9 +732,11 @@ class GRU(_RNNBase):
|
|
|
733
732
|
``Ascend`` ``GPU`` ``CPU``
|
|
734
733
|
|
|
735
734
|
Examples:
|
|
736
|
-
>>>
|
|
737
|
-
>>>
|
|
738
|
-
>>>
|
|
735
|
+
>>> import mindspore as ms
|
|
736
|
+
>>> import numpy as np
|
|
737
|
+
>>> net = ms.nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
|
|
738
|
+
>>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32))
|
|
739
|
+
>>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
|
|
739
740
|
>>> output, hn = net(x, h0)
|
|
740
741
|
>>> print(output.shape)
|
|
741
742
|
(3, 5, 16)
|
|
@@ -794,22 +795,22 @@ class LSTM(_RNNBase):
|
|
|
794
795
|
Args:
|
|
795
796
|
input_size (int): Number of features of input.
|
|
796
797
|
hidden_size (int): Number of features of hidden layer.
|
|
797
|
-
num_layers (int): Number of layers of stacked LSTM . Default: 1.
|
|
798
|
-
has_bias (bool): Whether the cell has bias `
|
|
799
|
-
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
|
|
798
|
+
num_layers (int): Number of layers of stacked LSTM . Default: ``1`` .
|
|
799
|
+
has_bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{hh}`. Default: ``True`` .
|
|
800
|
+
batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` .
|
|
800
801
|
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
|
|
801
|
-
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0).
|
|
802
|
+
LSTM layer except the last layer. Default ``0`` . The range of dropout is [0.0, 1.0).
|
|
802
803
|
bidirectional (bool): Specifies whether it is a bidirectional LSTM,
|
|
803
|
-
num_directions=2 if bidirectional=True otherwise 1. Default: False.
|
|
804
|
+
num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` .
|
|
805
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
804
806
|
|
|
805
807
|
Inputs:
|
|
806
808
|
- **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
|
|
807
|
-
shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)
|
|
809
|
+
shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)` .
|
|
808
810
|
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
|
|
809
|
-
or mindspore.float16 and shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)
|
|
810
|
-
The data type of `hx` must be the same as `x`.
|
|
811
|
+
or mindspore.float16 and shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` .
|
|
811
812
|
- **seq_length** (Tensor) - The length of each sequence in an input batch.
|
|
812
|
-
Tensor of shape :math:`(batch\_size)`. Default: None.
|
|
813
|
+
Tensor of shape :math:`(batch\_size)`. Default: ``None`` .
|
|
813
814
|
This input indicates the real sequence length before padding to avoid padded elements
|
|
814
815
|
have been used to compute hidden state and affect the final output. It is recommended to
|
|
815
816
|
use this input when **x** has padding elements.
|
|
@@ -831,10 +832,12 @@ class LSTM(_RNNBase):
|
|
|
831
832
|
``Ascend`` ``GPU`` ``CPU``
|
|
832
833
|
|
|
833
834
|
Examples:
|
|
834
|
-
>>>
|
|
835
|
-
>>>
|
|
836
|
-
>>>
|
|
837
|
-
>>>
|
|
835
|
+
>>> import mindspore as ms
|
|
836
|
+
>>> import numpy as np
|
|
837
|
+
>>> net = ms.nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
|
|
838
|
+
>>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32))
|
|
839
|
+
>>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
|
|
840
|
+
>>> c0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
|
|
838
841
|
>>> output, (hn, cn) = net(x, (h0, c0))
|
|
839
842
|
>>> print(output.shape)
|
|
840
843
|
(3, 5, 16)
|