mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Geometric Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops.operations import _inner_ops as inner
|
|
19
20
|
from mindspore.ops import composite as C
|
|
20
21
|
from mindspore import _checkparam as Validator
|
|
@@ -33,10 +34,10 @@ class Geometric(Distribution):
|
|
|
33
34
|
trials when the first success is achieved.
|
|
34
35
|
|
|
35
36
|
Args:
|
|
36
|
-
probs (float, list, numpy.ndarray, Tensor): The probability of success. Default: None.
|
|
37
|
-
seed (int): The seed used in sampling. Global seed is used if it is None. Default: None.
|
|
38
|
-
dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
|
|
39
|
-
name (str): The name of the distribution. Default: 'Geometric'.
|
|
37
|
+
probs (float, list, numpy.ndarray, Tensor): The probability of success. Default: ``None`` .
|
|
38
|
+
seed (int): The seed used in sampling. Global seed is used if it is None. Default: ``None`` .
|
|
39
|
+
dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` .
|
|
40
|
+
name (str): The name of the distribution. Default: ``'Geometric'`` .
|
|
40
41
|
|
|
41
42
|
Note:
|
|
42
43
|
`probs` must be a proper probability (0 < p < 1).
|
|
@@ -160,7 +161,6 @@ class Geometric(Distribution):
|
|
|
160
161
|
self.cast = P.Cast()
|
|
161
162
|
self.const = P.ScalarToTensor()
|
|
162
163
|
self.dtypeop = P.DType()
|
|
163
|
-
self.fill = P.Fill()
|
|
164
164
|
self.floor = P.Floor()
|
|
165
165
|
self.issubclass = inner.IsSubClass()
|
|
166
166
|
self.less = P.Less()
|
|
@@ -212,7 +212,7 @@ class Geometric(Distribution):
|
|
|
212
212
|
MODE(Geo) = 0
|
|
213
213
|
"""
|
|
214
214
|
probs1 = self._check_param_type(probs1)
|
|
215
|
-
return
|
|
215
|
+
return F.fill(self.dtype, self.shape(probs1), 0.)
|
|
216
216
|
|
|
217
217
|
def _var(self, probs1=None):
|
|
218
218
|
r"""
|
|
@@ -260,7 +260,7 @@ class Geometric(Distribution):
|
|
|
260
260
|
value = self.floor(value)
|
|
261
261
|
probs1 = self._check_param_type(probs1)
|
|
262
262
|
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
|
263
|
-
zeros =
|
|
263
|
+
zeros = F.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
|
|
264
264
|
comp = self.less(value, zeros)
|
|
265
265
|
return self.select(comp, zeros, pmf)
|
|
266
266
|
|
|
@@ -283,7 +283,7 @@ class Geometric(Distribution):
|
|
|
283
283
|
probs1 = self._check_param_type(probs1)
|
|
284
284
|
probs0 = 1.0 - probs1
|
|
285
285
|
cdf = 1.0 - self.pow(probs0, value + 1.0)
|
|
286
|
-
zeros =
|
|
286
|
+
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
|
287
287
|
comp = self.less(value, zeros)
|
|
288
288
|
return self.select(comp, zeros, cdf)
|
|
289
289
|
|
|
@@ -15,9 +15,9 @@
|
|
|
15
15
|
"""Gumbel Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore import _checkparam as Validator
|
|
19
20
|
from mindspore.common import dtype as mstype
|
|
20
|
-
import mindspore.nn as nn
|
|
21
21
|
import mindspore.nn.probability.bijector as msb
|
|
22
22
|
import mindspore.nn.probability.distribution as msd
|
|
23
23
|
from .transformed_distribution import TransformedDistribution
|
|
@@ -39,9 +39,9 @@ class Gumbel(TransformedDistribution):
|
|
|
39
39
|
Args:
|
|
40
40
|
loc (int, float, list, numpy.ndarray, Tensor): The location of Gumbel distribution.
|
|
41
41
|
scale (int, float, list, numpy.ndarray, Tensor): The scale of Gumbel distribution.
|
|
42
|
-
seed (int): the seed used in sampling. The global seed is used if it is None. Default: 0.
|
|
43
|
-
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
|
|
44
|
-
name (str): the name of the distribution. Default: 'Gumbel'.
|
|
42
|
+
seed (int): the seed used in sampling. The global seed is used if it is None. Default: ``0`` .
|
|
43
|
+
dtype (mindspore.dtype): type of the distribution. Default: ``mstype.float32`` .
|
|
44
|
+
name (str): the name of the distribution. Default: ``'Gumbel'`` .
|
|
45
45
|
|
|
46
46
|
Note:
|
|
47
47
|
`scale` must be greater than zero.
|
|
@@ -102,8 +102,7 @@ class Gumbel(TransformedDistribution):
|
|
|
102
102
|
self.const = P.ScalarToTensor()
|
|
103
103
|
self.exp = exp_generic
|
|
104
104
|
self.expm1 = P.Expm1()
|
|
105
|
-
self.
|
|
106
|
-
self.lgamma = nn.LGamma()
|
|
105
|
+
self.lgamma = P.Lgamma()
|
|
107
106
|
self.log = log_generic
|
|
108
107
|
self.shape = P.Shape()
|
|
109
108
|
self.squeeze = P.Squeeze(0)
|
|
@@ -164,7 +163,7 @@ class Gumbel(TransformedDistribution):
|
|
|
164
163
|
"""
|
|
165
164
|
The mode of the distribution.
|
|
166
165
|
"""
|
|
167
|
-
return self.loc *
|
|
166
|
+
return self.loc * F.fill(self.parameter_type, self.shape(self.scale), 1.0)
|
|
168
167
|
|
|
169
168
|
def _sd(self):
|
|
170
169
|
r"""
|
|
@@ -174,7 +173,7 @@ class Gumbel(TransformedDistribution):
|
|
|
174
173
|
STD(X) = \frac{\pi}{\sqrt(6)} * scale
|
|
175
174
|
"""
|
|
176
175
|
scale = self.scale * \
|
|
177
|
-
|
|
176
|
+
F.fill(self.parameter_type, self.broadcast_shape, 1.0)
|
|
178
177
|
return scale * np.pi / self.sqrt(self.const(6., mstype.float32))
|
|
179
178
|
|
|
180
179
|
def _entropy(self):
|
|
@@ -185,7 +184,7 @@ class Gumbel(TransformedDistribution):
|
|
|
185
184
|
H(X) = 1. + \log(scale) + Euler-Mascheroni_constant
|
|
186
185
|
"""
|
|
187
186
|
scale = self.scale * \
|
|
188
|
-
|
|
187
|
+
F.fill(self.parameter_type, self.broadcast_shape, 1.0)
|
|
189
188
|
return 1. + self.log(scale) + np.euler_gamma
|
|
190
189
|
|
|
191
190
|
def _log_prob(self, value):
|
|
@@ -37,12 +37,12 @@ class HalfNormal(Distribution):
|
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
39
|
mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
|
|
40
|
-
If this arg is None, then the mean of the distribution will be passed in runtime. Default: None.
|
|
40
|
+
If this arg is ``None`` , then the mean of the distribution will be passed in runtime. Default: ``None`` .
|
|
41
41
|
sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The standard deviation of the distribution.
|
|
42
|
-
If this arg is None, then the sd of the distribution will be passed in runtime. Default: None.
|
|
43
|
-
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
|
|
44
|
-
dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
|
|
45
|
-
name (str, optional): The name of the distribution. Default: 'HalfNormal'.
|
|
42
|
+
If this arg is ``None`` , then the sd of the distribution will be passed in runtime. Default: ``None`` .
|
|
43
|
+
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
|
|
44
|
+
dtype (mindspore.dtype, optional): The type of the event samples. Default: ``mstype.float32`` .
|
|
45
|
+
name (str, optional): The name of the distribution. Default: ``'HalfNormal'`` .
|
|
46
46
|
|
|
47
47
|
Note:
|
|
48
48
|
- `sd` must be greater than zero.
|
|
@@ -35,12 +35,12 @@ class Laplace(Distribution):
|
|
|
35
35
|
|
|
36
36
|
Args:
|
|
37
37
|
mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
|
|
38
|
-
If this arg is None, then the mean of the distribution will be passed in runtime. Default: None.
|
|
38
|
+
If this arg is ``None`` , then the mean of the distribution will be passed in runtime. Default: ``None`` .
|
|
39
39
|
sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The scale of the distribution.
|
|
40
|
-
If this arg is None, then the scale of the distribution will be passed in runtime. Default: None.
|
|
41
|
-
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
|
|
42
|
-
dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
|
|
43
|
-
name (str, optional): The name of the distribution. Default: 'Laplace'.
|
|
40
|
+
If this arg is ``None`` , then the scale of the distribution will be passed in runtime. Default: ``None`` .
|
|
41
|
+
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
|
|
42
|
+
dtype (mindspore.dtype, optional): The type of the event samples. Default: ``mstype.float32`` .
|
|
43
|
+
name (str, optional): The name of the distribution. Default: ``'Laplace'`` .
|
|
44
44
|
|
|
45
45
|
Note:
|
|
46
46
|
- `sd` must be greater than zero.
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""LogNormal Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.common import dtype as mstype
|
|
19
20
|
import mindspore.nn.probability.bijector as msb
|
|
20
21
|
import mindspore.nn.probability.distribution as msd
|
|
@@ -37,12 +38,13 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
37
38
|
It is constructed as the exponential transformation of a Normal distribution.
|
|
38
39
|
|
|
39
40
|
Args:
|
|
40
|
-
loc (int, float, list, numpy.ndarray, Tensor): The mean of the underlying Normal distribution.
|
|
41
|
+
loc (int, float, list, numpy.ndarray, Tensor): The mean of the underlying Normal distribution.
|
|
42
|
+
Default: ``None`` .
|
|
41
43
|
scale (int, float, list, numpy.ndarray, Tensor): The standard deviation of the underlying
|
|
42
|
-
Normal distribution. Default: None.
|
|
43
|
-
seed (int): the seed used in sampling. The global seed is used if it is None. Default: 0.
|
|
44
|
-
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
|
|
45
|
-
name (str): the name of the distribution. Default: 'LogNormal'.
|
|
44
|
+
Normal distribution. Default: ``None`` .
|
|
45
|
+
seed (int): the seed used in sampling. The global seed is used if it is None. Default: ``0`` .
|
|
46
|
+
dtype (mindspore.dtype): type of the distribution. Default: ``mstype.float32`` .
|
|
47
|
+
name (str): the name of the distribution. Default: ``'LogNormal'`` .
|
|
46
48
|
|
|
47
49
|
Note:
|
|
48
50
|
`scale` must be greater than zero.
|
|
@@ -100,7 +102,6 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
100
102
|
self.expm1 = P.Expm1()
|
|
101
103
|
self.log = log_generic
|
|
102
104
|
self.erf = P.Erf()
|
|
103
|
-
self.fill = P.Fill()
|
|
104
105
|
self.greater = P.Greater()
|
|
105
106
|
self.select = P.Select()
|
|
106
107
|
self.shape = P.Shape()
|
|
@@ -201,7 +202,7 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
201
202
|
cdf = self.distribution("cdf", inverse_value, mean, sd)
|
|
202
203
|
|
|
203
204
|
# to increase numerical stability, set cdf = 0 when value <= 0
|
|
204
|
-
zeros =
|
|
205
|
+
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
|
205
206
|
|
|
206
207
|
return self.select(self.greater(value, 0.), cdf, zeros)
|
|
207
208
|
|
|
@@ -229,8 +230,8 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
229
230
|
dist (str): The type of the distributions. Should be "LogNormal" in this case.
|
|
230
231
|
loc_b (Tensor): The loc of distribution b.
|
|
231
232
|
scale_b (Tensor): The scale of distribution b.
|
|
232
|
-
loc_a (Tensor): The loc of distribution a. Default: None
|
|
233
|
-
scale_a (Tensor): The scale of distribution a. Default: None
|
|
233
|
+
loc_a (Tensor): The loc of distribution a. Default: ``None``.
|
|
234
|
+
scale_a (Tensor): The scale of distribution a. Default: ``None``.
|
|
234
235
|
"""
|
|
235
236
|
check_distribution_name(dist, 'LogNormal')
|
|
236
237
|
return self._entropy(loc_a, scale_a) + self._kl_loss(dist, loc_b, scale_b, loc_a, scale_a)
|
|
@@ -243,8 +244,8 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
243
244
|
dist (str): The type of the distributions. Should be "LogNormal" in this case.
|
|
244
245
|
loc_b (Tensor): The loc of distribution b.
|
|
245
246
|
scale_b (Tensor): The scale of distribution b.
|
|
246
|
-
loc_a (Tensor): The loc of distribution a. Default: None
|
|
247
|
-
scale_a (Tensor): The scale of distribution a. Default: None
|
|
247
|
+
loc_a (Tensor): The loc of distribution a. Default: ``None``.
|
|
248
|
+
scale_a (Tensor): The scale of distribution a. Default: ``None``.
|
|
248
249
|
|
|
249
250
|
.. math::
|
|
250
251
|
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Logistic Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
from mindspore import _checkparam as Validator
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
@@ -35,11 +36,11 @@ class Logistic(Distribution):
|
|
|
35
36
|
where :math:`a, b` are loc and scale parameter respectively.
|
|
36
37
|
|
|
37
38
|
Args:
|
|
38
|
-
loc (float, list, numpy.ndarray, Tensor): The location of the Logistic distribution. Default: None.
|
|
39
|
-
scale (float, list, numpy.ndarray, Tensor): The scale of the Logistic distribution. Default: None.
|
|
40
|
-
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
|
41
|
-
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
|
42
|
-
name (str): The name of the distribution. Default: 'Logistic'.
|
|
39
|
+
loc (float, list, numpy.ndarray, Tensor): The location of the Logistic distribution. Default: ``None`` .
|
|
40
|
+
scale (float, list, numpy.ndarray, Tensor): The scale of the Logistic distribution. Default: ``None`` .
|
|
41
|
+
seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
|
|
42
|
+
dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
|
|
43
|
+
name (str): The name of the distribution. Default: ``'Logistic'`` .
|
|
43
44
|
|
|
44
45
|
Note:
|
|
45
46
|
`scale` must be greater than zero.
|
|
@@ -153,7 +154,6 @@ class Logistic(Distribution):
|
|
|
153
154
|
self.dtypeop = P.DType()
|
|
154
155
|
self.exp = exp_generic
|
|
155
156
|
self.expm1 = P.Expm1()
|
|
156
|
-
self.fill = P.Fill()
|
|
157
157
|
self.less = P.Less()
|
|
158
158
|
self.log = log_generic
|
|
159
159
|
self.log1p = P.Log1p()
|
|
@@ -171,7 +171,7 @@ class Logistic(Distribution):
|
|
|
171
171
|
|
|
172
172
|
self.threshold = np.log(np.finfo(np.float32).eps) + 1.
|
|
173
173
|
self.tiny = np.finfo(np.float).tiny
|
|
174
|
-
self.sd_const = np.pi/np.sqrt(3)
|
|
174
|
+
self.sd_const = np.pi / np.sqrt(3)
|
|
175
175
|
|
|
176
176
|
def _softplus(self, x):
|
|
177
177
|
too_small = self.less(x, self.threshold)
|
|
@@ -179,7 +179,7 @@ class Logistic(Distribution):
|
|
|
179
179
|
too_small_value = self.exp(x)
|
|
180
180
|
too_large_value = x
|
|
181
181
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
|
182
|
-
ones =
|
|
182
|
+
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
|
|
183
183
|
x = self.select(too_small_or_too_large, ones, x)
|
|
184
184
|
y = self.log(self.exp(x) + 1.0)
|
|
185
185
|
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
|
|
@@ -36,11 +36,12 @@ class Normal(Distribution):
|
|
|
36
36
|
the standard deviation of the normal distribution respectively.
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
-
mean (int, float, list, numpy.ndarray, Tensor): The mean of the Normal distribution. Default: None.
|
|
40
|
-
sd (int, float, list, numpy.ndarray, Tensor): The standard deviation of the Normal distribution.
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
39
|
+
mean (int, float, list, numpy.ndarray, Tensor): The mean of the Normal distribution. Default: ``None`` .
|
|
40
|
+
sd (int, float, list, numpy.ndarray, Tensor): The standard deviation of the Normal distribution.
|
|
41
|
+
Default: ``None`` .
|
|
42
|
+
seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
|
|
43
|
+
dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
|
|
44
|
+
name (str): The name of the distribution. Default: ``'Normal'`` .
|
|
44
45
|
|
|
45
46
|
Note:
|
|
46
47
|
`sd` must be greater than zero.
|
|
@@ -15,8 +15,8 @@
|
|
|
15
15
|
"""Poisson Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
|
-
import mindspore.nn as nn
|
|
20
20
|
from mindspore import _checkparam as Validator
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
22
|
from .distribution import Distribution
|
|
@@ -36,10 +36,10 @@ class Poisson(Distribution):
|
|
|
36
36
|
where :math:`\lambda` is the rate of the distribution.
|
|
37
37
|
|
|
38
38
|
Args:
|
|
39
|
-
rate (list, numpy.ndarray, Tensor): The rate of the Poisson distribution. Default: None.
|
|
40
|
-
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
|
41
|
-
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
|
42
|
-
name (str): The name of the distribution. Default: 'Poisson'.
|
|
39
|
+
rate (list, numpy.ndarray, Tensor): The rate of the Poisson distribution. Default: ``None`` .
|
|
40
|
+
seed (int): The seed used in sampling. The global seed is used if it is ``None`` . Default: ``None`` .
|
|
41
|
+
dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
|
|
42
|
+
name (str): The name of the distribution. Default: ``'Poisson'`` .
|
|
43
43
|
|
|
44
44
|
Note:
|
|
45
45
|
`rate` must be strictly greater than 0.
|
|
@@ -150,12 +150,11 @@ class Poisson(Distribution):
|
|
|
150
150
|
self.floor = P.Floor()
|
|
151
151
|
self.dtypeop = P.DType()
|
|
152
152
|
self.shape = P.Shape()
|
|
153
|
-
self.fill = P.Fill()
|
|
154
153
|
self.less = P.Less()
|
|
155
154
|
self.equal = P.Equal()
|
|
156
155
|
self.select = P.Select()
|
|
157
|
-
self.lgamma =
|
|
158
|
-
self.igamma =
|
|
156
|
+
self.lgamma = P.Lgamma()
|
|
157
|
+
self.igamma = P.Igamma()
|
|
159
158
|
self.poisson = C.poisson
|
|
160
159
|
|
|
161
160
|
@property
|
|
@@ -229,8 +228,8 @@ class Poisson(Distribution):
|
|
|
229
228
|
value = self.cast(value, self.dtype)
|
|
230
229
|
rate = self._check_param_type(rate)
|
|
231
230
|
log_rate = self.log(rate)
|
|
232
|
-
zeros =
|
|
233
|
-
inf =
|
|
231
|
+
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
232
|
+
inf = F.fill(self.dtypeop(value), self.shape(value), np.inf)
|
|
234
233
|
safe_x = self.select(self.less(value, zeros), zeros, value)
|
|
235
234
|
y = log_rate * safe_x - self.lgamma(safe_x + 1.)
|
|
236
235
|
comp = self.equal(value, safe_x)
|
|
@@ -255,7 +254,7 @@ class Poisson(Distribution):
|
|
|
255
254
|
value = self._check_value(value, 'value')
|
|
256
255
|
value = self.cast(value, self.dtype)
|
|
257
256
|
rate = self._check_param_type(rate)
|
|
258
|
-
zeros =
|
|
257
|
+
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
259
258
|
comp = self.less(value, zeros)
|
|
260
259
|
safe_x = self.select(comp, zeros, value)
|
|
261
260
|
cdf = 1. - self.igamma(1. + safe_x, rate)
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
from __future__ import division
|
|
18
18
|
import numpy as np
|
|
19
|
-
import mindspore.nn as nn
|
|
20
19
|
from mindspore.ops import operations as P
|
|
21
20
|
from mindspore import _checkparam as Validator
|
|
22
21
|
from mindspore.common import dtype as mstype
|
|
@@ -39,14 +38,14 @@ class StudentT(Distribution):
|
|
|
39
38
|
|
|
40
39
|
Args:
|
|
41
40
|
df (Union[int, float, list, numpy.ndarray, Tensor], optional): The degrees of freedom.
|
|
42
|
-
If this arg is None, then the df of the distribution will be passed in runtime. Default: None.
|
|
41
|
+
If this arg is ``None`` , then the df of the distribution will be passed in runtime. Default: ``None`` .
|
|
43
42
|
mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
|
|
44
|
-
If this arg is None, then the df of the distribution will be passed in runtime. Default: None.
|
|
43
|
+
If this arg is ``None`` , then the df of the distribution will be passed in runtime. Default: ``None`` .
|
|
45
44
|
sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The standard deviation of the distribution.
|
|
46
|
-
If this arg is None, then the sd of the distribution will be passed in runtime. Default: None.
|
|
47
|
-
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
|
|
48
|
-
dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
|
|
49
|
-
name (str, optional): The name of the distribution. Default: 'StudentT'.
|
|
45
|
+
If this arg is ``None`` , then the sd of the distribution will be passed in runtime. Default: ``None`` .
|
|
46
|
+
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
|
|
47
|
+
dtype (mindspore.dtype, optional): The type of the event samples. Default: ``mstype.float32`` .
|
|
48
|
+
name (str, optional): The name of the distribution. Default: ``'StudentT'`` .
|
|
50
49
|
|
|
51
50
|
Note:
|
|
52
51
|
- `df` must be greater than zero.
|
|
@@ -123,7 +122,7 @@ class StudentT(Distribution):
|
|
|
123
122
|
self.abs = P.Abs()
|
|
124
123
|
self.half = 0.5
|
|
125
124
|
self.half_log_pi = 0.5 * np.log(np.pi)
|
|
126
|
-
self.lgamma =
|
|
125
|
+
self.lgamma = P.Lgamma()
|
|
127
126
|
|
|
128
127
|
def _log_prob(self, value, df=None, mean=None, sd=None):
|
|
129
128
|
r"""
|
|
@@ -146,5 +145,5 @@ class StudentT(Distribution):
|
|
|
146
145
|
y = (value - mean) / sd
|
|
147
146
|
log_unnormalized_prob = -0.5 * (df + 1.) * self.log1p(y**2. / df)
|
|
148
147
|
log_normalization = self.log(self.abs(sd)) + 0.5 * self.log(df) + self.half_log_pi + \
|
|
149
|
-
|
|
148
|
+
self.lgamma(self.half * df) - self.lgamma(self.half * (df + 1.))
|
|
150
149
|
return log_unnormalized_prob - log_normalization
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore import _checkparam as validator
|
|
18
18
|
from mindspore.ops import operations as P
|
|
19
|
+
from mindspore.ops import functional as F
|
|
19
20
|
from mindspore.common import dtype as mstype
|
|
20
21
|
import mindspore.nn as nn
|
|
21
22
|
from .distribution import Distribution
|
|
@@ -35,10 +36,10 @@ class TransformedDistribution(Distribution):
|
|
|
35
36
|
Args:
|
|
36
37
|
bijector (Bijector): The transformation to perform.
|
|
37
38
|
distribution (Distribution): The original distribution. Must be a float dtype.
|
|
38
|
-
seed (int): The seed is used in sampling. The global seed is used if it is None. Default: None.
|
|
39
|
+
seed (int): The seed is used in sampling. The global seed is used if it is None. Default: ``None`` .
|
|
39
40
|
If this seed is given when a TransformedDistribution object is initialized, the object's sampling function
|
|
40
41
|
will use this seed; elsewise, the underlying distribution's seed will be used.
|
|
41
|
-
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.
|
|
42
|
+
name (str): The name of the transformed distribution. Default: ``'transformed_distribution'`` .
|
|
42
43
|
|
|
43
44
|
Note:
|
|
44
45
|
The arguments used to initialize the original distribution cannot be None.
|
|
@@ -125,7 +126,6 @@ class TransformedDistribution(Distribution):
|
|
|
125
126
|
self.cast_base = P.Cast()
|
|
126
127
|
self.equal_base = P.Equal()
|
|
127
128
|
self.select_base = P.Select()
|
|
128
|
-
self.fill_base = P.Fill()
|
|
129
129
|
|
|
130
130
|
# broadcast bijector batch_shape and distribution batch_shape
|
|
131
131
|
self._broadcast_shape = self._broadcast_bijector_dist()
|
|
@@ -176,9 +176,9 @@ class TransformedDistribution(Distribution):
|
|
|
176
176
|
"""
|
|
177
177
|
if self.batch_shape is None or self.bijector.batch_shape is None:
|
|
178
178
|
return None
|
|
179
|
-
bijector_shape_tensor =
|
|
179
|
+
bijector_shape_tensor = F.fill(
|
|
180
180
|
self.dtype, self.bijector.batch_shape, 0.0)
|
|
181
|
-
dist_shape_tensor =
|
|
181
|
+
dist_shape_tensor = F.fill(self.dtype, self.batch_shape, 0.0)
|
|
182
182
|
return (bijector_shape_tensor + dist_shape_tensor).shape
|
|
183
183
|
|
|
184
184
|
def _cdf(self, value, *args, **kwargs):
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Uniform Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
|
+
from mindspore.ops import functional as F
|
|
17
18
|
from mindspore.ops import operations as P
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
from mindspore import _checkparam as Validator
|
|
@@ -35,11 +36,11 @@ class Uniform(Distribution):
|
|
|
35
36
|
where :math:`a, b` are the lower and upper bound respectively.
|
|
36
37
|
|
|
37
38
|
Args:
|
|
38
|
-
low (int, float, list, numpy.ndarray, Tensor): The lower bound of the distribution. Default: None.
|
|
39
|
-
high (int, float, list, numpy.ndarray, Tensor): The upper bound of the distribution. Default: None.
|
|
40
|
-
seed (int): The seed uses in sampling. The global seed is used if it is None. Default: None.
|
|
41
|
-
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
|
42
|
-
name (str): The name of the distribution. Default: 'Uniform'.
|
|
39
|
+
low (int, float, list, numpy.ndarray, Tensor): The lower bound of the distribution. Default: ``None`` .
|
|
40
|
+
high (int, float, list, numpy.ndarray, Tensor): The upper bound of the distribution. Default: ``None`` .
|
|
41
|
+
seed (int): The seed uses in sampling. The global seed is used if it is ``None`` . Default: ``None`` .
|
|
42
|
+
dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
|
|
43
|
+
name (str): The name of the distribution. Default: ``'Uniform'`` .
|
|
43
44
|
|
|
44
45
|
Note:
|
|
45
46
|
`low` must be strictly less than `high`.
|
|
@@ -170,7 +171,6 @@ class Uniform(Distribution):
|
|
|
170
171
|
self.cast = P.Cast()
|
|
171
172
|
self.const = P.ScalarToTensor()
|
|
172
173
|
self.dtypeop = P.DType()
|
|
173
|
-
self.fill = P.Fill()
|
|
174
174
|
self.less = P.Less()
|
|
175
175
|
self.lessequal = P.LessEqual()
|
|
176
176
|
self.logicaland = P.LogicalAnd()
|
|
@@ -287,10 +287,10 @@ class Uniform(Distribution):
|
|
|
287
287
|
value = self._check_value(value, 'value')
|
|
288
288
|
value = self.cast(value, self.dtype)
|
|
289
289
|
low, high = self._check_param_type(low, high)
|
|
290
|
-
neg_ones =
|
|
290
|
+
neg_ones = F.fill(self.dtype, self.shape(value), -1.0)
|
|
291
291
|
prob = self.exp(neg_ones * self.log(high - low))
|
|
292
292
|
broadcast_shape = self.shape(prob)
|
|
293
|
-
zeros =
|
|
293
|
+
zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
|
294
294
|
comp_lo = self.less(value, low)
|
|
295
295
|
comp_hi = self.lessequal(value, high)
|
|
296
296
|
less_than_low = self.select(comp_lo, zeros, prob)
|
|
@@ -316,7 +316,7 @@ class Uniform(Distribution):
|
|
|
316
316
|
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
|
|
317
317
|
comp = self.logicaland(self.lessequal(
|
|
318
318
|
low_b, low_a), self.lessequal(high_a, high_b))
|
|
319
|
-
inf =
|
|
319
|
+
inf = F.fill(self.dtypeop(kl), self.shape(kl), np.inf)
|
|
320
320
|
return self.select(comp, kl, inf)
|
|
321
321
|
|
|
322
322
|
def _cdf(self, value, low=None, high=None):
|
|
@@ -338,8 +338,8 @@ class Uniform(Distribution):
|
|
|
338
338
|
low, high = self._check_param_type(low, high)
|
|
339
339
|
prob = (value - low) / (high - low)
|
|
340
340
|
broadcast_shape = self.shape(prob)
|
|
341
|
-
zeros =
|
|
342
|
-
ones =
|
|
341
|
+
zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
|
342
|
+
ones = F.fill(self.dtypeop(prob), broadcast_shape, 1.0)
|
|
343
343
|
comp_lo = self.less(value, low)
|
|
344
344
|
comp_hi = self.less(value, high)
|
|
345
345
|
less_than_low = self.select(comp_lo, zeros, prob)
|
|
@@ -32,9 +32,9 @@ class TensorArray(Cell):
|
|
|
32
32
|
Args:
|
|
33
33
|
dtype (mindspore.dtype): the data type in the TensorArray.
|
|
34
34
|
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
|
|
35
|
-
dynamic_size (bool): if true, the size of TensorArray can be increased. Default: True.
|
|
35
|
+
dynamic_size (bool): if ``true`` , the size of TensorArray can be increased. Default: ``True`` .
|
|
36
36
|
size (int): if dynamic_size=False, `size` means the max_size of the TensorArray.
|
|
37
|
-
name (string): the name of this TensorArray. Default: "TA".
|
|
37
|
+
name (string): the name of this TensorArray. Default: ``"TA"`` .
|
|
38
38
|
|
|
39
39
|
Supported Platforms:
|
|
40
40
|
``GPU`` ``CPU``
|
mindspore/nn/sparse/sparse.py
CHANGED
|
@@ -45,7 +45,7 @@ class SparseToDense(Cell):
|
|
|
45
45
|
TypeError: If `sparse_tensor.shape` is not a tuple.
|
|
46
46
|
|
|
47
47
|
Supported Platforms:
|
|
48
|
-
``
|
|
48
|
+
``CPU``
|
|
49
49
|
|
|
50
50
|
Examples:
|
|
51
51
|
>>> import mindspore as ms
|
|
@@ -90,8 +90,8 @@ class SparseTensorDenseMatmul(Cell):
|
|
|
90
90
|
The rank of sparse matrix and dense matrix must be equal to `2`.
|
|
91
91
|
|
|
92
92
|
Args:
|
|
93
|
-
adjoint_st (bool): If true, sparse tensor is transposed before multiplication. Default: False.
|
|
94
|
-
adjoint_dt (bool): If true, dense tensor is transposed before multiplication. Default: False.
|
|
93
|
+
adjoint_st (bool): If ``true`` , sparse tensor is transposed before multiplication. Default: ``False`` .
|
|
94
|
+
adjoint_dt (bool): If ``true`` , dense tensor is transposed before multiplication. Default: ``False`` .
|
|
95
95
|
|
|
96
96
|
Inputs:
|
|
97
97
|
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
|
|
@@ -101,15 +101,15 @@ class SparseTensorDenseMatmul(Cell):
|
|
|
101
101
|
- **sparse_shape** (tuple) - A positive int tuple which specifies the shape of sparse tensor,
|
|
102
102
|
should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
|
|
103
103
|
- **dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
|
|
104
|
-
If `adjoint_st` is False and `adjoint_dt` is False, the shape must be :math:`(C, M)`.
|
|
105
|
-
If `adjoint_st` is False and `adjoint_dt` is True, the shape must be :math:`(M, C)`.
|
|
106
|
-
If `adjoint_st` is True and `adjoint_dt` is False, the shape must be :math:`(N, M)`.
|
|
107
|
-
If `adjoint_st` is True and `adjoint_dt` is True, the shape must be :math:`(M, N)`.
|
|
104
|
+
If `adjoint_st` is ``False`` and `adjoint_dt` is ``False`` , the shape must be :math:`(C, M)`.
|
|
105
|
+
If `adjoint_st` is ``False`` and `adjoint_dt` is ``True`` , the shape must be :math:`(M, C)`.
|
|
106
|
+
If `adjoint_st` is ``True`` and `adjoint_dt` is ``False`` , the shape must be :math:`(N, M)`.
|
|
107
|
+
If `adjoint_st` is ``True`` and `adjoint_dt` is ``True`` , the shape must be :math:`(M, N)`.
|
|
108
108
|
|
|
109
109
|
Outputs:
|
|
110
110
|
Tensor, the dtype is the same as `values`.
|
|
111
|
-
If `adjoint_st` is False, the shape is :math:`(N, M)`.
|
|
112
|
-
If `adjoint_st` is True, the shape is :math:`(C, M)`.
|
|
111
|
+
If `adjoint_st` is ``False`` , the shape is :math:`(N, M)`.
|
|
112
|
+
If `adjoint_st` is ``True`` , the shape is :math:`(C, M)`.
|
|
113
113
|
|
|
114
114
|
Raises:
|
|
115
115
|
TypeError: If the type of `adjoint_st` or `adjoint_dt` is not bool, or the dtype of `indices`,
|