mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/probability/bijector/softplus.py

@@ -15,6 +15,7 @@
 """Softplus Bijector"""
 import numpy as np
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.nn.layer.activation import LogSigmoid
 from ..distribution._utils.custom_ops import exp_generic, log_generic
 from .bijector import Bijector

@@ -31,8 +32,8 @@ class Softplus(Bijector):
     where k is the sharpness factor.

     Args:
-        sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
-        name (str): The name of the Bijector. Default: 'Softplus'.
+        sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``1.0`` .
+        name (str): The name of the Bijector. Default: ``'Softplus'`` .

     Note:
         The dtype of `sharpness` must be float.

@@ -84,7 +85,6 @@
         self.abs = P.Abs()
         self.dtypeop = P.DType()
         self.cast = P.Cast()
-        self.fill = P.Fill()
         self.greater = P.Greater()
         self.less = P.Less()
         self.log_sigmoid = LogSigmoid()

@@ -103,7 +103,7 @@
         too_large = self.greater(x, -self.threshold)
         too_small_value = self.exp(x)
         too_large_value = x
-        ones =
+        ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
         too_small_or_too_large = self.logicalor(too_small, too_large)
         x = self.select(too_small_or_too_large, ones, x)
         y = self.log(self.exp(x) + 1.0)

@@ -119,7 +119,7 @@
         too_large = self.greater(x, (-1) * self.threshold)
         too_small_value = self.log(x)
         too_large_value = x
-        ones =
+        ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
         too_small_or_too_large = self.logicalor(too_small, too_large)
         x = self.select(too_small_or_too_large, ones, x)
         y = x + self.log(self.abs(self.expm1((-1)*x)))
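The functional change in the softplus.py hunks is the switch from the removed P.Fill() primitive to the functional F.fill(dtype, shape, value) call (the removed `ones =` assignments appear truncated in this diff view). A minimal standalone sketch of the new call pattern, outside the bijector class:

import mindspore as ms
from mindspore.ops import operations as P
from mindspore.ops import functional as F

x = ms.Tensor([[0.5, -2.0], [3.0, 0.1]], ms.float32)
dtypeop = P.DType()   # primitive returning a tensor's dtype
shape = P.Shape()     # primitive returning a tensor's shape as a tuple
# Same call shape as the added lines: fill(dtype, shape, fill_value)
ones = F.fill(dtypeop(x), shape(x), 1.0)
print(ones)           # tensor of ones with x's dtype and shape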
mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py

@@ -27,8 +27,10 @@ class WithBNNLossCell(Cell):
     Args:
         backbone (Cell): The target network.
         loss_fn (Cell): The loss function used to compute loss.
-        dnn_factor(int, float): The coefficient of backbone's loss, which is computed by the loss function.
-
+        dnn_factor(int, float): The coefficient of backbone's loss, which is computed by the loss function.
+            Default: ``1`` .
+        bnn_factor(int, float): The coefficient of KL loss, which is the KL divergence of Bayesian layer.
+            Default: ``1`` .

     Inputs:
         - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
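The WithBNNLossCell hunk only adds documentation for the `dnn_factor` and `bnn_factor` arguments, both defaulting to ``1``. A hedged usage sketch, assuming the wrapper keeps its existing bnn_layers constructor:

from mindspore import nn
from mindspore.nn.probability import bnn_layers

backbone = bnn_layers.DenseReparam(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
# dnn_factor scales the backbone loss computed by loss_fn; bnn_factor scales
# the KL divergence of the Bayesian layers. Both default to 1.
net_with_loss = bnn_layers.WithBNNLossCell(backbone, loss_fn, dnn_factor=1, bnn_factor=1)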
mindspore/nn/probability/bnn_layers/conv_variational.py

@@ -157,11 +157,11 @@ class ConvReparam(_ConvVariational):
         stride(Union[int, tuple[int]]): The distance of kernel moving,
             an integer number represents that the height and width of movement
             are both strides, or a tuple of two integers numbers represents that
-            height and width of movement respectively. Default: 1.
+            height and width of movement respectively. Default: ``1`` .
         pad_mode (str): Specifies the padding mode. The optional values are
-            "same", "valid", and "pad". Default: "same".
+            ``"same"`` , ``"valid"`` , and ``"pad"`` . Default: ``"same"`` .

-            - same
+            - ``"same"``: Adopts the way of completion. Output height and width
              will be the same as the input.
              The total number of padding will be calculated for in horizontal and
              vertical directions and evenly distributed to top and bottom,

@@ -169,43 +169,43 @@ class ConvReparam(_ConvVariational):
              will be done from the bottom and the right side. If this mode
              is set, `padding` must be 0.

-            - valid
+            - ``"valid"``: Adopts the way of discarding. The possible largest
              height and width of the output will be returned without padding.
              Extra pixels will be discarded. If this mode is set, `padding`
              must be 0.

-            - pad
+            - ``"pad"``: Implicit paddings on both sides of the input. The number
              of `padding` will be padded to the input Tensor borders.
              `padding` must be greater than or equal to 0.

         padding (Union[int, tuple[int]]): Implicit paddings on both sides of
-            the input. Default: 0.
+            the input. Default: ``0`` .
         dilation (Union[int, tuple[int]]): The data type is an integer or a tuple
             of 2 integers. This parameter specifies the dilation rate of the
             dilated convolution. If set to be :math:`k > 1`,
             there will be :math:`k - 1` pixels skipped for each sampling
             location. Its value must be greater or equal to 1 and bounded
-            by the height and width of the input. Default: 1.
+            by the height and width of the input. Default: ``1`` .
         group (int): Splits filter into groups, `in_ channels` and
             `out_channels` must be divisible by the number of groups.
-            Default: 1.
+            Default: ``1`` .
         has_bias (bool): Specifies whether the layer uses a bias vector.
-            Default: False.
+            Default: ``False`` .
         weight_prior_fn (Cell): The prior distribution for weight.
             It must return a mindspore distribution instance.
-            Default: NormalPrior. (which creates an instance of standard
+            Default: ``NormalPrior`` . (which creates an instance of standard
             normal distribution). The current version only supports normal distribution.
         weight_posterior_fn (function): The posterior distribution for sampling weight.
             It must be a function handle which returns a mindspore
-            distribution instance. Default: normal_post_fn.
+            distribution instance. Default: ``normal_post_fn`` .
             The current version only supports normal distribution.
         bias_prior_fn (Cell): The prior distribution for bias vector. It must return
-            a mindspore distribution. Default: NormalPrior(which creates an
+            a mindspore distribution. Default: ``NormalPrior`` (which creates an
             instance of standard normal distribution). The current version
             only supports normal distribution.
         bias_posterior_fn (function): The posterior distribution for sampling bias vector.
             It must be a function handle which returns a mindspore
-            distribution instance. Default: normal_post_fn.
+            distribution instance. Default: ``normal_post_fn`` .
             The current version only supports normal distribution.

     Inputs:
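The ConvReparam hunks are documentation-only: default values gain double-backtick markup and the `pad_mode` bullets are expanded. A hedged construction sketch using the documented defaults, with argument names assumed to mirror nn.Conv2d:

from mindspore.nn.probability import bnn_layers

# Documented defaults: stride=1, pad_mode="same", padding=0, dilation=1,
# group=1, has_bias=False, with NormalPrior / normal_post_fn distributions.
conv = bnn_layers.ConvReparam(in_channels=16, out_channels=32, kernel_size=3,
                              stride=1, pad_mode="same", padding=0,
                              dilation=1, group=1, has_bias=False)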
@@ -136,23 +136,23 @@ class DenseReparam(_DenseVariational):
|
|
|
136
136
|
activation (str, Cell): A regularization function applied to the output of the layer.
|
|
137
137
|
The type of `activation` can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()).
|
|
138
138
|
Note that if the type of activation is Cell, it must be instantiated beforehand.
|
|
139
|
-
Default: None.
|
|
140
|
-
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
|
139
|
+
Default: ``None`` .
|
|
140
|
+
has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``False`` .
|
|
141
141
|
weight_prior_fn (Cell): The prior distribution for weight.
|
|
142
142
|
It must return a mindspore distribution instance.
|
|
143
|
-
Default: NormalPrior. (which creates an instance of standard
|
|
143
|
+
Default: ``NormalPrior`` . (which creates an instance of standard
|
|
144
144
|
normal distribution). The current version only supports normal distribution.
|
|
145
145
|
weight_posterior_fn (function): The posterior distribution for sampling weight.
|
|
146
146
|
It must be a function handle which returns a mindspore
|
|
147
|
-
distribution instance. Default: normal_post_fn.
|
|
147
|
+
distribution instance. Default: ``normal_post_fn`` .
|
|
148
148
|
The current version only supports normal distribution.
|
|
149
149
|
bias_prior_fn (Cell): The prior distribution for bias vector. It must return
|
|
150
|
-
a mindspore distribution. Default: NormalPrior(which creates an
|
|
150
|
+
a mindspore distribution. Default: ``NormalPrior`` (which creates an
|
|
151
151
|
instance of standard normal distribution). The current version
|
|
152
152
|
only supports normal distribution.
|
|
153
153
|
bias_posterior_fn (function): The posterior distribution for sampling bias vector.
|
|
154
154
|
It must be a function handle which returns a mindspore
|
|
155
|
-
distribution instance. Default: normal_post_fn.
|
|
155
|
+
distribution instance. Default: ``normal_post_fn`` .
|
|
156
156
|
The current version only supports normal distribution.
|
|
157
157
|
|
|
158
158
|
Inputs:
|
|
@@ -230,23 +230,23 @@ class DenseLocalReparam(_DenseVariational):
|
|
|
230
230
|
activation (str, Cell): A regularization function applied to the output of the layer.
|
|
231
231
|
The type of `activation` can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()).
|
|
232
232
|
Note that if the type of activation is Cell, it must be instantiated beforehand.
|
|
233
|
-
Default: None.
|
|
234
|
-
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
|
233
|
+
Default: ``None`` .
|
|
234
|
+
has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``False`` .
|
|
235
235
|
weight_prior_fn (Cell): The prior distribution for weight.
|
|
236
236
|
It must return a mindspore distribution instance.
|
|
237
|
-
Default: NormalPrior. (which creates an instance of standard
|
|
237
|
+
Default: ``NormalPrior`` . (which creates an instance of standard
|
|
238
238
|
normal distribution). The current version only supports normal distribution.
|
|
239
239
|
weight_posterior_fn (function): The posterior distribution for sampling weight.
|
|
240
240
|
It must be a function handle which returns a mindspore
|
|
241
|
-
distribution instance. Default: normal_post_fn.
|
|
241
|
+
distribution instance. Default: ``normal_post_fn`` .
|
|
242
242
|
The current version only supports normal distribution.
|
|
243
243
|
bias_prior_fn (Cell): The prior distribution for bias vector. It must return
|
|
244
|
-
a mindspore distribution. Default: NormalPrior(which creates an
|
|
244
|
+
a mindspore distribution. Default: ``NormalPrior`` (which creates an
|
|
245
245
|
instance of standard normal distribution). The current version
|
|
246
246
|
only supports normal distribution.
|
|
247
247
|
bias_posterior_fn (function): The posterior distribution for sampling bias vector.
|
|
248
248
|
It must be a function handle which returns a mindspore
|
|
249
|
-
distribution instance. Default: normal_post_fn.
|
|
249
|
+
distribution instance. Default: ``normal_post_fn`` .
|
|
250
250
|
The current version only supports normal distribution.
|
|
251
251
|
|
|
252
252
|
Inputs:
|
|
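The hunks above only touch the docstring convention for default values (plain defaults become literal-formatted ``None``, ``False``, ``NormalPrior``, ``normal_post_fn``); the layer behaviour is unchanged. For orientation, a minimal usage sketch of the variational dense layer these docstrings describe; the import path follows the public mindspore.nn.probability.bnn_layers API and the shapes are purely illustrative:

    # Minimal usage sketch (illustrative shapes; assumes the public bnn_layers API).
    import numpy as np
    import mindspore as ms
    from mindspore.nn.probability.bnn_layers import DenseReparam

    layer = DenseReparam(4, 3, activation='relu')        # has_bias defaults to False per the docstring
    x = ms.Tensor(np.random.randn(2, 4).astype(np.float32))
    out = layer(x)                                       # weights are sampled from the posterior on each call
    print(out.shape)                                     # (2, 3)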
@@ -30,9 +30,9 @@ class NormalPrior(Cell):

    Args:
        dtype (mindspore.dtype): The argument is used to define the data type of the output tensor.
-           Default: mindspore.float32.
-       mean (int, float): Mean of normal distribution. Default: 0.
-       std (int, float): Standard deviation of normal distribution. Default: 0.1.
+           Default: ``mindspore.float32`` .
+       mean (int, float): Mean of normal distribution. Default: ``0`` .
+       std (int, float): Standard deviation of normal distribution. Default: ``0.1`` .

    Returns:
        Cell, a normal distribution.

@@ -56,12 +56,13 @@ class NormalPosterior(Cell):
        name (str): Name prepended to trainable parameter.
        shape (list, tuple): Shape of the mean and standard deviation.
        dtype (mindspore.dtype): The argument is used to define the data type of the output tensor.
-           Default: mindspore.float32.
-       loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
-       loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
-       untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters.
+           Default: ``mindspore.float32`` .
+       loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: ``0`` .
+       loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: ``0.1`` .
+       untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters.
+           Default: ``-5`` .
        untransformed_scale_std (int, float): Standard deviation of distribution to initialize trainable parameters.
-           Default: 0.1.
+           Default: ``0.1`` .

    Returns:
        Cell, a normal distribution.
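Here too only the documented defaults are reformatted (and the ``-5`` default for untransformed_scale_mean is now spelled out). A small construction sketch of the two Cells under those defaults; the keyword names mirror the docstring but should be treated as assumptions:

    # Construction sketch under the documented defaults (keyword names assumed from the docstring).
    import mindspore as ms
    from mindspore.nn.probability.bnn_layers import NormalPrior, NormalPosterior

    prior = NormalPrior(dtype=ms.float32, mean=0, std=0.1)      # standard-normal-style prior
    posterior = NormalPosterior(name='weight', shape=[3, 4])    # loc_mean=0, loc_std=0.1,
                                                                # untransformed_scale_mean=-5,
                                                                # untransformed_scale_std=0.1 by default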
@@ -15,8 +15,17 @@
 """Utility functions to help distribution class."""
 import numpy as np
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.ops.operations import _inner_ops as inner
+from mindspore.ops.primitive import constexpr
 from mindspore.common import dtype as mstype
+from .utils import CheckTensor
+
+
+@constexpr(check=False)
+def _check_tensor(x, name):
+    CheckTensor()(x, name)
+    return x


 def exp_generic(input_x):

@@ -44,7 +53,6 @@ def log_generic(input_x):
     log = P.Log()
     less = P.Less()
     lessequal = P.LessEqual()
-    fill = P.Fill()
     cast = P.Cast()
     dtype = P.DType()
     shape = P.Shape()

@@ -53,8 +61,8 @@ def log_generic(input_x):

     if not checktype(dtype(input_x), mstype.float_):
         input_x = cast(input_x, mstype.float32)
-    nan = fill(dtype(input_x), shape(input_x), np.nan)
-    inf = fill(dtype(input_x), shape(input_x), np.inf)
+    nan = F.fill(dtype(input_x), shape(input_x), np.nan)
+    inf = F.fill(dtype(input_x), shape(input_x), np.inf)
     neg_x = less(input_x, 0.0)
     nonpos_x = lessequal(input_x, 0.0)
     log_x = log(input_x)
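These hunks show one instance of a change repeated throughout this release: the P.Fill() primitive instance is removed and the functional form F.fill(dtype, shape, value), imported as `from mindspore.ops import functional as F`, is called directly. A standalone sketch of the new call with illustrative values:

    # Standalone sketch of the functional fill call that replaces P.Fill() (illustrative values).
    import numpy as np
    import mindspore as ms
    from mindspore.ops import functional as F

    x = ms.Tensor(np.array([1.0, -2.0, 3.0], dtype=np.float32))
    nan = F.fill(x.dtype, x.shape, np.nan)   # NaN tensor with x's dtype and shape
    inf = F.fill(x.dtype, x.shape, np.inf)   # +inf tensor with x's dtype and shape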
@@ -63,6 +71,14 @@ def log_generic(input_x):
     return select(neg_x, nan, result)


+def log_generic_with_check(x):
+    """
+    log generic with input check
+    """
+    _check_tensor(x, "the input of log_generic")
+    return log_generic(x)
+
+
 def log1p_generic(x):
     """
     Log1p ops on GPU device or when device_target == GPU.
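The new log_generic_with_check wrapper validates its argument through the _check_tensor constexpr added above and then falls through to the unchecked log_generic; it is intended for call sites that receive user-supplied values rather than already-validated parameters. A hedged usage sketch, assuming eager (PyNative) execution; the import path is inferred from the relative imports elsewhere in this diff and is an assumption:

    # Hedged usage sketch: the checked variant accepts Tensors and rejects other inputs.
    import mindspore as ms
    from mindspore.nn.probability.distribution._utils.custom_ops import log_generic_with_check

    x = ms.Tensor([1.0, 2.0], ms.float32)
    y = log_generic_with_check(x)            # same result as log_generic(x)
    # log_generic_with_check([1.0, 2.0])     # would fail the CheckTensor validation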
@@ -315,7 +315,7 @@ class CheckTensor(PrimitiveWithInfer):
     def __infer__(self, x, name):
         src_type = x['dtype']
         validator.check_subclass(
-            "input", src_type, [mstype.
+            "input", src_type, [mstype.tensor_type], name["value"])

         out = {'shape': None,
                'dtype': None,

@@ -15,6 +15,7 @@
 """Bernoulli Distribution"""
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore import _checkparam as Validator
 from .distribution import Distribution

@@ -29,10 +30,10 @@ class Bernoulli(Distribution):
     and the probability mass function as :math:`P(X = 0) = p, P(X = 1) = 1-p`.

     Args:
-        probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: None.
-        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
-        name (str): The name of the distribution. Default: 'Bernoulli'.
+        probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: ``None`` .
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` .
+        name (str): The name of the distribution. Default: ``'Bernoulli'`` .

     Note:
         `probs` must be a proper probability (0 < p < 1).

@@ -151,7 +152,6 @@ class Bernoulli(Distribution):
         self.cast = P.Cast()
         self.const = P.ScalarToTensor()
         self.floor = P.Floor()
-        self.fill = P.Fill()
         self.less = P.Less()
         self.shape = P.Shape()
         self.select = P.Select()

@@ -200,8 +200,8 @@ class Bernoulli(Distribution):
         MODE(B) = 1 if probs1 > 0.5 else = 0
         """
         probs1 = self._check_param_type(probs1)
-        zeros =
-        ones =
+        zeros = F.fill(self.dtype, self.shape(probs1), 0.0)
+        ones = F.fill(self.dtype, self.shape(probs1), 1.0)
         comp = self.less(0.5, probs1)
         return self.select(comp, ones, zeros)

@@ -278,9 +278,9 @@ class Bernoulli(Distribution):
         probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
         comp_zero = self.less(value, 0.0)
         comp_one = self.less(value, 1.0)
-        zeros =
+        zeros = F.fill(self.parameter_type, self.shape(
             broadcast_shape_tensor), 0.0)
-        ones =
+        ones = F.fill(self.parameter_type, self.shape(
             broadcast_shape_tensor), 1.0)
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)
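In the Bernoulli class the change is again mechanical: self.fill disappears from the constructor and the mode / cdf paths build their constant tensors with F.fill. For reference, a small usage sketch of the two methods touched here (the probability value is illustrative):

    # Usage sketch for the Bernoulli methods touched above (illustrative probability).
    import mindspore as ms
    import mindspore.nn.probability.distribution as msd

    b = msd.Bernoulli(probs=0.7)
    print(b.mode())                              # 1, because probs > 0.5
    print(b.cdf(ms.Tensor(0.5, ms.float32)))     # P(X <= 0.5) = 1 - probs = 0.3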
@@ -15,6 +15,7 @@
 """Beta Distribution"""
 import numpy as np
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 import mindspore.nn as nn
 from mindspore import _checkparam as Validator

@@ -36,12 +37,12 @@ class Beta(Distribution):

     Args:
         concentration1 (int, float, list, numpy.ndarray, Tensor): The concentration1,
-            also know as alpha of the Beta distribution. Default: None.
+            also know as alpha of the Beta distribution. Default: ``None`` .
         concentration0 (int, float, list, numpy.ndarray, Tensor): The concentration0, also know as
-            beta of the Beta distribution. Default: None.
-        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
-        name (str): The name of the distribution. Default: 'Beta'.
+            beta of the Beta distribution. Default: ``None`` .
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
+        name (str): The name of the distribution. Default: ``'Beta'`` .

     Note:
         - `concentration1` and `concentration0` must be greater than zero.

@@ -186,7 +187,6 @@ class Beta(Distribution):
         self.pow = P.Pow()
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.select = P.Select()
         self.logicaland = P.LogicalAnd()

@@ -266,7 +266,7 @@ class Beta(Distribution):
         comp2 = self.greater(concentration0, 1.)
         cond = self.logicaland(comp1, comp2)
         batch_shape = self.shape(concentration1 + concentration0)
-        nan =
+        nan = F.fill(self.dtype, batch_shape, np.nan)
         mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
         return self.select(cond, mode, nan)
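The Beta hunks follow the same pattern; the mode computation keeps the closed form mode = (concentration1 - 1) / (concentration1 + concentration0 - 2), which is only valid when both concentrations exceed 1 (otherwise the F.fill-built NaN tensor is selected). A quick plain-Python check of that formula:

    # Plain-Python check of the Beta mode formula used above.
    concentration1, concentration0 = 3.0, 2.0                              # both > 1, so the closed form applies
    mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
    print(mode)                                                            # 0.666..., the mode of Beta(3, 2)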
@@ -379,7 +379,7 @@ class Beta(Distribution):
             sample_shape = (1,)
         else:
             sample_shape = origin_shape
-        ones =
+        ones = F.fill(self.dtype, sample_shape, 1.0)
         sample_gamma1 = C.gamma(
             sample_shape, alpha=concentration1, beta=ones, seed=self.seed)
         sample_gamma2 = C.gamma(

@@ -15,7 +15,9 @@
 """Categorical Distribution"""
 import numpy as np
 from mindspore import context
+from mindspore.common import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops.functional import stop_gradient
 from mindspore.ops.operations import _inner_ops as inner

@@ -26,7 +28,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
     check_distribution_name
-from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
+from ._utils.custom_ops import exp_generic, log_generic, broadcast_to, log_generic_with_check


 class Categorical(Distribution):

@@ -36,10 +38,10 @@ class Categorical(Distribution):
     and the probability mass function as :math:`P(X = i) = p_i, i = 1, ..., k`.

     Args:
-        probs (Tensor, list, numpy.ndarray): Event probabilities. Default: None.
-        seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
-        name (str): The name of the distribution. Default: Categorical.
+        probs (Tensor, list, numpy.ndarray): Event probabilities. Default: ``None`` .
+        seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` .
+        name (str): The name of the distribution. Default: ``Categorical`` .

     Note:
         `probs` must have rank at least 1, values are proper probabilities and sum to 1.

@@ -148,7 +150,6 @@ class Categorical(Distribution):
         self.dtypeop = P.DType()
         self.exp = exp_generic
         self.expand_dim = P.ExpandDims()
-        self.fill = P.Fill()
         self.gather = P.GatherNd()
         self.greater = P.Greater()
         self.issubclass = inner.IsSubClass()

@@ -156,6 +157,7 @@ class Categorical(Distribution):
         # when the graph kernel mode is enable
         # use Log directly as akg will handle the corner cases
         self.log = P.Log() if context.get_context("enable_graph_kernel") else log_generic
+        self.log_with_check = P.Log() if context.get_context("enable_graph_kernel") else log_generic_with_check
         self.log_softmax = P.LogSoftmax()
         self.logicor = P.LogicalOr()
         self.logicand = P.LogicalAnd()

@@ -253,8 +255,11 @@ class Categorical(Distribution):
         probs_b = self._check_value(probs_b, 'probs_b')
         probs_b = self.cast(probs_b, self.parameter_type)
         probs_a = self._check_param_type(probs)
-
-
+        if probs is None:
+            logits_a = self.log(probs_a)
+        else:
+            logits_a = self.log_with_check(probs_a)
+        logits_b = self.log_with_check(probs_b)
         return self.squeeze(self.reduce_sum(
             self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1))
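The last hunk above restores the body of the KL computation: logits for the distribution's own parameter (when probs is None) use the unchecked self.log, since that parameter was validated at construction, while user-supplied probabilities go through the new self.log_with_check. A hedged sketch of the KL entry point, with the argument order assumed from the generic Distribution.kl_loss convention:

    # Hedged sketch of the KL path touched above (argument order assumed from Distribution.kl_loss).
    import mindspore as ms
    import mindspore.nn.probability.distribution as msd

    p = msd.Categorical(probs=[0.2, 0.8])
    q_probs = ms.Tensor([0.5, 0.5], ms.float32)
    print(p.kl_loss('Categorical', q_probs))     # KL(p || q)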
@@ -287,7 +292,7 @@ class Categorical(Distribution):
         # here we simulate casting to int but still keeping float dtype
         value = self.cast(value, self.dtypeop(probs))

-        zeros =
+        zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
         between_zero_neone = self.logicand(self.less(value, 0,),
                                            self.greater(value, -1.))
         value = self.select(between_zero_neone,

@@ -323,15 +328,18 @@ class Categorical(Distribution):
         value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
         value_clipped = self.cast(value_clipped, self.index_type)
         # create index from 0 ... NumOfLabels
-
+        start = Tensor(0, self.index_type)
+        end = self.cast(self.shape(value)[0], self.index_type)
+        delta = Tensor(1, self.index_type)
+        index = self.reshape(ops.range(start, end, delta), (-1, 1))
         index = self.concat((index, value_clipped))

         # index into logit_pmf, fill in out_of_bound places with -inf
         # reshape into label shape N
         logits_pmf = self.gather(self.reshape(
             logits, (-1, num_classes)), index)
-        nan =
-
+        nan = F.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf),
+                     self.nan)
         logits_pmf = self.select(out_of_bound, nan, logits_pmf)
         ans = self.reshape(logits_pmf, label_shape)
         if drop_dim:
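The larger of these two hunks rebuilds the label index with ops.range, which here takes Tensor start, end, and delta arguments, and reshapes the result into a column so it can be concatenated with the clipped labels. A standalone sketch of that index construction with an illustrative size:

    # Standalone sketch of the index construction used above (illustrative number of labels).
    import mindspore as ms
    from mindspore import Tensor, ops

    start = Tensor(0, ms.int32)
    end = Tensor(4, ms.int32)                 # number of labels in this toy example
    delta = Tensor(1, ms.int32)
    index = ops.reshape(ops.range(start, end, delta), (-1, 1))
    print(index)                              # [[0], [1], [2], [3]]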
@@ -351,7 +359,7 @@ class Categorical(Distribution):

         value = self.cast(value, self.dtypeop(probs))

-        zeros =
+        zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
         between_zero_neone = self.logicand(
             self.less(value, 0,), self.greater(value, -1.))
         value = self.select(between_zero_neone, zeros, P.Floor()(value))

@@ -386,7 +394,7 @@ class Categorical(Distribution):
         # reshape probs and fill less_than_zero places with 0
         probs = self.reshape(probs, (-1, num_classes))
         cdf = self.gather(self.cumsum(probs, 1), index)
-        zeros =
+        zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         cdf = self.select(less_than_zero, zeros, cdf)
         cdf = self.reshape(cdf, label_shape)

@@ -417,7 +425,7 @@ class Categorical(Distribution):
             sample_shape = (1,)

         probs_2d = self.reshape(probs, (-1, num_classes))
-        sample_tensor =
+        sample_tensor = F.fill(self.dtype, shape, 1.0)
         sample_tensor = self.reshape(sample_tensor, (-1, 1))
         num_sample = self.shape(sample_tensor)[0]
         samples = C.multinomial(probs_2d, num_sample, seed=self.seed)

@@ -35,11 +35,11 @@ class Cauchy(Distribution):
     where :math:`a, b` are loc and scale parameter respectively.

     Args:
-        loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: None.
-        scale (int, float, list, numpy.ndarray, Tensor): The scale of the Cauchy distribution. Default: None.
-        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
-        name (str): The name of the distribution. Default: 'Cauchy'.
+        loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: ``None`` .
+        scale (int, float, list, numpy.ndarray, Tensor): The scale of the Cauchy distribution. Default: ``None`` .
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
+        name (str): The name of the distribution. Default: ``'Cauchy'`` .

     Note:
         `scale` must be greater than zero.

@@ -170,7 +170,6 @@ class Cauchy(Distribution):
         self.const = P.ScalarToTensor()
         self.dtypeop = P.DType()
         self.exp = exp_generic
-        self.fill = P.Fill()
         self.less = P.Less()
         self.log = log_generic
         self.log1p = log1p_generic

@@ -15,6 +15,7 @@
 """basic"""
 from mindspore import context
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.nn.cell import Cell
 from mindspore.ops.primitive import constexpr
 from mindspore.ops.operations import _inner_ops as inner

@@ -113,7 +114,6 @@ class Distribution(Cell):
         # ops needed for the base class
         self.cast_base = P.Cast()
         self.dtype_base = P.DType()
-        self.fill_base = P.Fill()
         self.sametypeshape_base = inner.SameTypeShape()
         self.sq_base = P.Square()
         self.sqrt_base = P.Sqrt()

@@ -194,11 +194,11 @@ class Distribution(Cell):
         if broadcast_shape is None:
             broadcast_shape = self.shape_base(arg)
             common_dtype = self.dtype_base(arg)
-            broadcast_shape_tensor =
+            broadcast_shape_tensor = F.fill(
                 common_dtype, broadcast_shape, 1.0)
         else:
             broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
-            broadcast_shape_tensor =
+            broadcast_shape_tensor = F.fill(
                 common_dtype, broadcast_shape, 1.0)
         arg = self.broadcast(arg, broadcast_shape_tensor)
         # check if the arguments have the same dtype
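In the Distribution base class the ones tensor used to broadcast an argument to the common batch shape is now built with F.fill instead of the removed self.fill_base. The effect is the usual broadcast-against-ones trick, sketched below with illustrative shapes:

    # Sketch of broadcasting an argument against a ones tensor of the target shape (illustrative shapes).
    import mindspore as ms
    from mindspore.ops import functional as F

    arg = ms.Tensor([1.0, 2.0, 3.0], ms.float32)
    broadcast_shape_tensor = F.fill(arg.dtype, (2, 3), 1.0)
    broadcast_arg = arg * broadcast_shape_tensor      # shape (2, 3); each row repeats arg
    print(broadcast_arg.shape)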
@@ -35,10 +35,10 @@ class Exponential(Distribution):
     where :math:`\lambda` is the rate of the distribution.

     Args:
-        rate (int, float, list, numpy.ndarray, Tensor): The inverse scale. Default: None.
-        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
-        name (str): The name of the distribution. Default: 'Exponential'.
+        rate (int, float, list, numpy.ndarray, Tensor): The inverse scale. Default: ``None`` .
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
+        name (str): The name of the distribution. Default: ``'Exponential'`` .

     Note:
         `rate` must be strictly greater than 0.

@@ -15,6 +15,7 @@
 """Gamma Distribution"""
 import numpy as np
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 import mindspore.nn as nn
 from mindspore import _checkparam as Validator

@@ -38,12 +39,12 @@ class Gamma(Distribution):

     Args:
         concentration (int, float, list, numpy.ndarray, Tensor): The concentration,
-            also know as alpha of the Gamma distribution. Default: None.
+            also know as alpha of the Gamma distribution. Default: ``None`` .
         rate (int, float, list, numpy.ndarray, Tensor): The rate, also know as
-            beta of the Gamma distribution. Default: None.
-        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
-        name (str): The name of the distribution. Default: 'Gamma'.
+            beta of the Gamma distribution. Default: ``None`` .
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: ``None`` .
+        dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.float32`` .
+        name (str): The name of the distribution. Default: ``'Gamma'`` .

     Note:
         `concentration` and `rate` must be greater than zero.

@@ -185,13 +186,12 @@ class Gamma(Distribution):
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.dtypeop = P.DType()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.select = P.Select()
         self.greater = P.Greater()
-        self.lgamma =
+        self.lgamma = P.Lgamma()
         self.digamma = nn.DiGamma()
-        self.igamma =
+        self.igamma = P.Igamma()

     def extend_repr(self):
         """Display instance object as string."""
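Besides dropping P.Fill(), this constructor hunk sets the lgamma and igamma helpers to the P.Lgamma() and P.Igamma() primitives, which the Gamma class uses elsewhere (for example in its entropy and CDF computations). A small sketch of the two primitives with illustrative inputs:

    # Small sketch of the Lgamma / Igamma primitives now held by the Gamma distribution (illustrative inputs).
    import mindspore as ms
    from mindspore.ops import operations as P

    x = ms.Tensor([3.0], ms.float32)
    a = ms.Tensor([2.0], ms.float32)
    print(P.Lgamma()(x))                       # log(Gamma(3)) = log(2) ≈ 0.6931
    print(P.Igamma()(a, x))                    # lower regularized incomplete gamma P(a=2, x=3)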
@@ -265,8 +265,8 @@ class Gamma(Distribution):
         """
         concentration, rate = self._check_param_type(concentration, rate)
         mode = (concentration - 1.) / rate
-        nan =
-
+        nan = F.fill(self.dtypeop(concentration), self.shape(concentration),
+                     np.nan)
         comp = self.greater(concentration, 1.)
         return self.select(comp, mode, nan)