mindspore 1.10.0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -15,13 +15,13 @@
|
|
|
15
15
|
"""Utility functions to help distribution class."""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore import context
|
|
18
|
-
from mindspore
|
|
18
|
+
from mindspore import _checkparam as validator
|
|
19
19
|
from mindspore.common.tensor import Tensor
|
|
20
20
|
from mindspore.common.parameter import Parameter
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
|
-
from mindspore.ops import composite as C
|
|
23
22
|
from mindspore.ops import operations as P
|
|
24
|
-
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
|
|
23
|
+
from mindspore.ops.primitive import constexpr, _primexpr, PrimitiveWithInfer, prim_attr_register
|
|
24
|
+
import mindspore.ops as ops
|
|
25
25
|
import mindspore.nn as nn
|
|
26
26
|
|
|
27
27
|
|
|
@@ -214,7 +214,7 @@ def clamp_probs(probs):
|
|
|
214
214
|
clamp probs boundary
|
|
215
215
|
"""
|
|
216
216
|
eps = P.Eps()(probs)
|
|
217
|
-
return
|
|
217
|
+
return ops.clip_by_value(probs, eps, 1-eps)
|
|
218
218
|
|
|
219
219
|
|
|
220
220
|
def probs_to_logits(probs, is_binary=False):
|
|
@@ -230,48 +230,42 @@ def probs_to_logits(probs, is_binary=False):
|
|
|
230
230
|
return P.Log()(ps_clamped)
|
|
231
231
|
|
|
232
232
|
|
|
233
|
-
@constexpr
|
|
233
|
+
@constexpr(check=False)
|
|
234
234
|
def raise_none_error(name):
|
|
235
235
|
raise TypeError(f"the type {name} must be subclass of Tensor."
|
|
236
236
|
f" It can not be None since it is not specified during initialization.")
|
|
237
237
|
|
|
238
238
|
|
|
239
|
-
@
|
|
240
|
-
def raise_probs_logits_error():
|
|
241
|
-
raise TypeError(
|
|
242
|
-
"Either 'probs' or 'logits' must be specified, but not both.")
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
@constexpr
|
|
239
|
+
@_primexpr
|
|
246
240
|
def raise_broadcast_error(shape_a, shape_b):
|
|
247
241
|
raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")
|
|
248
242
|
|
|
249
243
|
|
|
250
|
-
@constexpr
|
|
244
|
+
@constexpr(check=False)
|
|
251
245
|
def raise_not_impl_error(name):
|
|
252
246
|
raise ValueError(
|
|
253
247
|
f"{name} function must be implemented for non-linear transformation")
|
|
254
248
|
|
|
255
249
|
|
|
256
|
-
@constexpr
|
|
250
|
+
@constexpr(check=False)
|
|
257
251
|
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
|
|
258
252
|
raise NotImplementedError(
|
|
259
253
|
f"{func_name} is not implemented for {obj} distribution.")
|
|
260
254
|
|
|
261
255
|
|
|
262
|
-
@constexpr
|
|
256
|
+
@constexpr(check=False)
|
|
263
257
|
def raise_type_error(name, cur_type, required_type):
|
|
264
258
|
raise TypeError(
|
|
265
259
|
f"For {name} , the type must be or be subclass of {required_type}, but got {cur_type}")
|
|
266
260
|
|
|
267
261
|
|
|
268
|
-
@constexpr
|
|
262
|
+
@constexpr(check=False)
|
|
269
263
|
def raise_not_defined(func_name, obj, *args, **kwargs):
|
|
270
264
|
raise ValueError(
|
|
271
265
|
f"{func_name} is undefined for {obj} distribution.")
|
|
272
266
|
|
|
273
267
|
|
|
274
|
-
@constexpr
|
|
268
|
+
@constexpr(check=False)
|
|
275
269
|
def check_distribution_name(name, expected_name):
|
|
276
270
|
if name is None:
|
|
277
271
|
raise ValueError(
|
|
@@ -16,16 +16,16 @@
|
|
|
16
16
|
from mindspore.common import dtype as mstype
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
18
|
from mindspore.ops import composite as C
|
|
19
|
-
from mindspore
|
|
19
|
+
from mindspore import _checkparam as Validator
|
|
20
20
|
from .distribution import Distribution
|
|
21
21
|
from ._utils.utils import check_prob, check_distribution_name, clamp_probs
|
|
22
22
|
from ._utils.custom_ops import exp_generic, log_generic
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class Bernoulli(Distribution):
|
|
26
|
-
"""
|
|
26
|
+
r"""
|
|
27
27
|
Bernoulli Distribution.
|
|
28
|
-
A Bernoulli Distribution is a discrete distribution with the range {0, 1}
|
|
28
|
+
A Bernoulli Distribution is a discrete distribution with the range :math:`\{0, 1\}`
|
|
29
29
|
and the probability mass function as :math:`P(X = 0) = p, P(X = 1) = 1-p`.
|
|
30
30
|
|
|
31
31
|
Args:
|
|
@@ -149,7 +149,7 @@ class Bernoulli(Distribution):
|
|
|
149
149
|
self.log = log_generic
|
|
150
150
|
self.squeeze = P.Squeeze(0)
|
|
151
151
|
self.cast = P.Cast()
|
|
152
|
-
self.const = P.
|
|
152
|
+
self.const = P.ScalarToTensor()
|
|
153
153
|
self.floor = P.Floor()
|
|
154
154
|
self.fill = P.Fill()
|
|
155
155
|
self.less = P.Less()
|
|
@@ -324,8 +324,8 @@ class Bernoulli(Distribution):
|
|
|
324
324
|
sample_shape = (1,)
|
|
325
325
|
else:
|
|
326
326
|
sample_shape = origin_shape
|
|
327
|
-
l_zero = self.const(0.0)
|
|
328
|
-
h_one = self.const(1.0)
|
|
327
|
+
l_zero = self.const(0.0, mstype.float32)
|
|
328
|
+
h_one = self.const(1.0, mstype.float32)
|
|
329
329
|
sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
|
|
330
330
|
sample = self.less(sample_uniform, probs1)
|
|
331
331
|
value = self.cast(sample, self.dtype)
|
|
@@ -17,7 +17,7 @@ import numpy as np
|
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
18
|
from mindspore.ops import composite as C
|
|
19
19
|
import mindspore.nn as nn
|
|
20
|
-
from mindspore
|
|
20
|
+
from mindspore import _checkparam as Validator
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
22
|
from .distribution import Distribution
|
|
23
23
|
from ._utils.utils import check_greater_zero, check_distribution_name
|
|
@@ -18,19 +18,21 @@ from mindspore import context
|
|
|
18
18
|
from mindspore.ops import operations as P
|
|
19
19
|
from mindspore.ops import composite as C
|
|
20
20
|
from mindspore.ops.functional import stop_gradient
|
|
21
|
-
from mindspore.
|
|
21
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
22
|
+
from mindspore import _checkparam as Validator
|
|
23
|
+
import mindspore.ops as ops
|
|
22
24
|
import mindspore.nn as nn
|
|
23
25
|
from mindspore.common import dtype as mstype
|
|
24
26
|
from .distribution import Distribution
|
|
25
27
|
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
|
|
26
|
-
check_distribution_name
|
|
28
|
+
check_distribution_name
|
|
27
29
|
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
class Categorical(Distribution):
|
|
31
|
-
"""
|
|
33
|
+
r"""
|
|
32
34
|
Categorical distribution.
|
|
33
|
-
A Categorical Distribution is a discrete distribution with the range {1, 2, ..., k}
|
|
35
|
+
A Categorical Distribution is a discrete distribution with the range :math:`\{1, 2, ..., k\}`
|
|
34
36
|
and the probability mass function as :math:`P(X = i) = p_i, i = 1, ..., k`.
|
|
35
37
|
|
|
36
38
|
Args:
|
|
@@ -140,7 +142,7 @@ class Categorical(Distribution):
|
|
|
140
142
|
self.argmax = P.ArgMaxWithValue(axis=-1)
|
|
141
143
|
self.broadcast = broadcast_to
|
|
142
144
|
self.cast = P.Cast()
|
|
143
|
-
self.clip_by_value =
|
|
145
|
+
self.clip_by_value = ops.clip_by_value
|
|
144
146
|
self.concat = P.Concat(-1)
|
|
145
147
|
self.cumsum = P.CumSum()
|
|
146
148
|
self.dtypeop = P.DType()
|
|
@@ -149,7 +151,7 @@ class Categorical(Distribution):
|
|
|
149
151
|
self.fill = P.Fill()
|
|
150
152
|
self.gather = P.GatherNd()
|
|
151
153
|
self.greater = P.Greater()
|
|
152
|
-
self.issubclass =
|
|
154
|
+
self.issubclass = inner.IsSubClass()
|
|
153
155
|
self.less = P.Less()
|
|
154
156
|
# when the graph kernel mode is enable
|
|
155
157
|
# use Log directly as akg will handle the corner cases
|
|
@@ -236,7 +238,7 @@ class Categorical(Distribution):
|
|
|
236
238
|
"""
|
|
237
239
|
probs = self._check_param_type(probs)
|
|
238
240
|
logits = self.log(probs)
|
|
239
|
-
return self.squeeze(
|
|
241
|
+
return self.squeeze(P.Neg()(self.reduce_sum(logits * probs, -1)))
|
|
240
242
|
|
|
241
243
|
def _kl_loss(self, dist, probs_b, probs=None):
|
|
242
244
|
"""
|
|
@@ -403,8 +405,6 @@ class Categorical(Distribution):
|
|
|
403
405
|
Returns:
|
|
404
406
|
Tensor, shape is shape(probs)[:-1] + sample_shape
|
|
405
407
|
"""
|
|
406
|
-
if self.device_target == 'Ascend':
|
|
407
|
-
raise_not_implemented_util('On d backend, sample', self.name)
|
|
408
408
|
shape = self.checktuple(shape, 'shape')
|
|
409
409
|
probs = self._check_param_type(probs)
|
|
410
410
|
num_classes = self.shape(probs)[-1]
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
18
|
from mindspore.ops import composite as C
|
|
19
|
-
from mindspore
|
|
19
|
+
from mindspore import _checkparam as Validator
|
|
20
20
|
from mindspore.common import dtype as mstype
|
|
21
21
|
from .distribution import Distribution
|
|
22
22
|
from ._utils.utils import check_greater_zero, check_distribution_name, raise_not_defined
|
|
@@ -26,13 +26,13 @@ from ._utils.custom_ops import exp_generic, log_generic, log1p_generic
|
|
|
26
26
|
class Cauchy(Distribution):
|
|
27
27
|
r"""
|
|
28
28
|
Cauchy distribution.
|
|
29
|
-
A Cauchy distributio is a continuous distribution with the range
|
|
29
|
+
A Cauchy distributio is a continuous distribution with the range of all real numbers
|
|
30
30
|
and the probability density function:
|
|
31
31
|
|
|
32
32
|
.. math::
|
|
33
33
|
f(x, a, b) = 1 / \pi b(1 - ((x - a)/b)^2),
|
|
34
34
|
|
|
35
|
-
where a
|
|
35
|
+
where :math:`a, b` are loc and scale parameter respectively.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
38
|
loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: None.
|
|
@@ -167,7 +167,7 @@ class Cauchy(Distribution):
|
|
|
167
167
|
# ops needed for the class
|
|
168
168
|
self.atan = P.Atan()
|
|
169
169
|
self.cast = P.Cast()
|
|
170
|
-
self.const = P.
|
|
170
|
+
self.const = P.ScalarToTensor()
|
|
171
171
|
self.dtypeop = P.DType()
|
|
172
172
|
self.exp = exp_generic
|
|
173
173
|
self.fill = P.Fill()
|
|
@@ -308,7 +308,7 @@ class Cauchy(Distribution):
|
|
|
308
308
|
value = self.cast(value, self.dtype)
|
|
309
309
|
loc, scale = self._check_param_type(loc, scale)
|
|
310
310
|
z = (value - loc) / scale
|
|
311
|
-
return self.log1p(2. * self.atan(z) / np.pi) - self.log(self.const(2.))
|
|
311
|
+
return self.log1p(2. * self.atan(z) / np.pi) - self.log(self.const(2., mstype.float32))
|
|
312
312
|
|
|
313
313
|
def _quantile(self, p, loc=None, scale=None):
|
|
314
314
|
loc, scale = self._check_param_type(loc, scale)
|
|
@@ -338,7 +338,7 @@ class Cauchy(Distribution):
|
|
|
338
338
|
sum_square = self.sq(scale_a + scale_b)
|
|
339
339
|
square_diff = self.sq(loc_a - loc_b)
|
|
340
340
|
return self.log(sum_square + square_diff) - \
|
|
341
|
-
self.log(self.const(4.0)) - self.log(scale_a) - self.log(scale_b)
|
|
341
|
+
self.log(self.const(4.0, mstype.float32)) - self.log(scale_a) - self.log(scale_b)
|
|
342
342
|
|
|
343
343
|
def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
|
|
344
344
|
r"""
|
|
@@ -374,8 +374,8 @@ class Cauchy(Distribution):
|
|
|
374
374
|
sample_shape = (1,)
|
|
375
375
|
else:
|
|
376
376
|
sample_shape = origin_shape
|
|
377
|
-
l_zero = self.const(0.0)
|
|
378
|
-
h_one = self.const(1.0)
|
|
377
|
+
l_zero = self.const(0.0, mstype.float32)
|
|
378
|
+
h_one = self.const(1.0, mstype.float32)
|
|
379
379
|
sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
|
|
380
380
|
sample = self._quantile(sample_uniform, loc, scale)
|
|
381
381
|
value = self.cast(sample, self.dtype)
|
|
@@ -17,7 +17,8 @@ from mindspore import context
|
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
18
|
from mindspore.nn.cell import Cell
|
|
19
19
|
from mindspore.ops.primitive import constexpr
|
|
20
|
-
from mindspore.
|
|
20
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
21
|
+
from mindspore import _checkparam as validator
|
|
21
22
|
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
|
|
22
23
|
raise_not_implemented_util
|
|
23
24
|
from ._utils.utils import CheckTuple, CheckTensor
|
|
@@ -101,7 +102,7 @@ class Distribution(Cell):
|
|
|
101
102
|
self.device_target = context.get_context('device_target')
|
|
102
103
|
self.checktuple = CheckTuple()
|
|
103
104
|
|
|
104
|
-
@constexpr
|
|
105
|
+
@constexpr(check=False)
|
|
105
106
|
def _check_tensor(x, name):
|
|
106
107
|
CheckTensor()(x, name)
|
|
107
108
|
return x
|
|
@@ -112,13 +113,17 @@ class Distribution(Cell):
|
|
|
112
113
|
# ops needed for the base class
|
|
113
114
|
self.cast_base = P.Cast()
|
|
114
115
|
self.dtype_base = P.DType()
|
|
115
|
-
self.exp_base = exp_generic
|
|
116
116
|
self.fill_base = P.Fill()
|
|
117
|
-
self.
|
|
118
|
-
self.sametypeshape_base = P.SameTypeShape()
|
|
117
|
+
self.sametypeshape_base = inner.SameTypeShape()
|
|
119
118
|
self.sq_base = P.Square()
|
|
120
119
|
self.sqrt_base = P.Sqrt()
|
|
121
120
|
self.shape_base = P.Shape()
|
|
121
|
+
if self.device_target != "Ascend":
|
|
122
|
+
self.log_base = P.Log()
|
|
123
|
+
self.exp_base = P.Exp()
|
|
124
|
+
else:
|
|
125
|
+
self.exp_base = exp_generic
|
|
126
|
+
self.log_base = log_generic
|
|
122
127
|
|
|
123
128
|
@property
|
|
124
129
|
def name(self):
|
|
@@ -427,7 +432,8 @@ class Distribution(Cell):
|
|
|
427
432
|
|
|
428
433
|
def prob(self, value, *args, **kwargs):
|
|
429
434
|
"""
|
|
430
|
-
Evaluate the probability (pdf or pmf) at given value.
|
|
435
|
+
Evaluate the probability (pdf or pmf) at given value. For a discrete distribution,
|
|
436
|
+
it is a probability mass function, while for a continuous distribution, it is probability density function.
|
|
431
437
|
|
|
432
438
|
Args:
|
|
433
439
|
value (Tensor): value to be evaluated.
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
18
|
from mindspore.ops import composite as C
|
|
19
|
-
from mindspore
|
|
19
|
+
from mindspore import _checkparam as Validator
|
|
20
20
|
from mindspore.common import dtype as mstype
|
|
21
21
|
from .distribution import Distribution
|
|
22
22
|
from ._utils.utils import check_greater_zero, check_distribution_name
|
|
@@ -26,7 +26,7 @@ from ._utils.custom_ops import exp_generic, log_generic
|
|
|
26
26
|
class Exponential(Distribution):
|
|
27
27
|
r"""
|
|
28
28
|
Exponential Distribution.
|
|
29
|
-
An Exponential distributio is a continuous distribution with the range :math:`[0,
|
|
29
|
+
An Exponential distributio is a continuous distribution with the range :math:`[0, \inf)`
|
|
30
30
|
and the probability density function:
|
|
31
31
|
|
|
32
32
|
.. math::
|
|
@@ -159,7 +159,7 @@ class Exponential(Distribution):
|
|
|
159
159
|
self.log = log_generic
|
|
160
160
|
self.squeeze = P.Squeeze(0)
|
|
161
161
|
self.cast = P.Cast()
|
|
162
|
-
self.const = P.
|
|
162
|
+
self.const = P.ScalarToTensor()
|
|
163
163
|
self.dtypeop = P.DType()
|
|
164
164
|
self.fill = P.Fill()
|
|
165
165
|
self.less = P.Less()
|
|
@@ -340,8 +340,8 @@ class Exponential(Distribution):
|
|
|
340
340
|
sample_shape = (1,)
|
|
341
341
|
else:
|
|
342
342
|
sample_shape = origin_shape
|
|
343
|
-
minval = self.const(self.minval)
|
|
344
|
-
maxval = self.const(1.0)
|
|
343
|
+
minval = self.const(self.minval, mstype.float32)
|
|
344
|
+
maxval = self.const(1.0, mstype.float32)
|
|
345
345
|
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
|
|
346
346
|
sample = self.log(sample_uniform) / rate
|
|
347
347
|
value = self.cast(-sample, self.dtype)
|
|
@@ -17,7 +17,7 @@ import numpy as np
|
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
18
|
from mindspore.ops import composite as C
|
|
19
19
|
import mindspore.nn as nn
|
|
20
|
-
from mindspore
|
|
20
|
+
from mindspore import _checkparam as Validator
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
22
|
from .distribution import Distribution
|
|
23
23
|
from ._utils.utils import check_greater_zero, check_distribution_name
|
|
@@ -27,14 +27,14 @@ from ._utils.custom_ops import log_generic
|
|
|
27
27
|
class Gamma(Distribution):
|
|
28
28
|
r"""
|
|
29
29
|
Gamma distribution.
|
|
30
|
-
A Gamma distributio is a continuous distribution with the range :math:`
|
|
30
|
+
A Gamma distributio is a continuous distribution with the range :math:`(0, \inf)`
|
|
31
31
|
and the probability density function:
|
|
32
32
|
|
|
33
33
|
.. math::
|
|
34
34
|
f(x, \alpha, \beta) = \beta^\alpha / \Gamma(\alpha) x^{\alpha - 1} \exp(-\beta x).
|
|
35
35
|
|
|
36
36
|
where :math:`G` is the Gamma function,
|
|
37
|
-
and :math:`\alpha
|
|
37
|
+
and :math:`\alpha` and :math:`\beta` are the concentration and the rate of the distribution respectively.
|
|
38
38
|
|
|
39
39
|
Args:
|
|
40
40
|
concentration (int, float, list, numpy.ndarray, Tensor): The concentration,
|
|
@@ -15,8 +15,9 @@
|
|
|
15
15
|
"""Geometric Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
|
-
from mindspore
|
|
20
|
+
from mindspore import _checkparam as Validator
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
21
22
|
from .distribution import Distribution
|
|
22
23
|
from ._utils.utils import check_prob, check_distribution_name
|
|
@@ -157,11 +158,11 @@ class Geometric(Distribution):
|
|
|
157
158
|
self.log = log_generic
|
|
158
159
|
self.squeeze = P.Squeeze(0)
|
|
159
160
|
self.cast = P.Cast()
|
|
160
|
-
self.const = P.
|
|
161
|
+
self.const = P.ScalarToTensor()
|
|
161
162
|
self.dtypeop = P.DType()
|
|
162
163
|
self.fill = P.Fill()
|
|
163
164
|
self.floor = P.Floor()
|
|
164
|
-
self.issubclass =
|
|
165
|
+
self.issubclass = inner.IsSubClass()
|
|
165
166
|
self.less = P.Less()
|
|
166
167
|
self.pow = P.Pow()
|
|
167
168
|
self.select = P.Select()
|
|
@@ -324,8 +325,8 @@ class Geometric(Distribution):
|
|
|
324
325
|
sample_shape = (1,)
|
|
325
326
|
else:
|
|
326
327
|
sample_shape = origin_shape
|
|
327
|
-
minval = self.const(self.minval)
|
|
328
|
-
maxval = self.const(1.0)
|
|
328
|
+
minval = self.const(self.minval, mstype.float32)
|
|
329
|
+
maxval = self.const(1.0, mstype.float32)
|
|
329
330
|
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
|
|
330
331
|
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
|
|
331
332
|
value = self.cast(sample, self.dtype)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Gumbel Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
-
from mindspore
|
|
18
|
+
from mindspore import _checkparam as Validator
|
|
19
19
|
from mindspore.common import dtype as mstype
|
|
20
20
|
import mindspore.nn as nn
|
|
21
21
|
import mindspore.nn.probability.bijector as msb
|
|
@@ -28,13 +28,13 @@ from ._utils.custom_ops import exp_generic, log_generic
|
|
|
28
28
|
class Gumbel(TransformedDistribution):
|
|
29
29
|
r"""
|
|
30
30
|
Gumbel distribution.
|
|
31
|
-
A Gumbel distributio is a continuous distribution with the range
|
|
31
|
+
A Gumbel distributio is a continuous distribution with the range of all real numbers
|
|
32
32
|
and the probability density function:
|
|
33
33
|
|
|
34
34
|
.. math::
|
|
35
35
|
f(x, a, b) = 1 / b \exp(\exp(-(x - a) / b) - x),
|
|
36
36
|
|
|
37
|
-
where a
|
|
37
|
+
where :math:`a, b` are loc and scale parameter respectively.
|
|
38
38
|
|
|
39
39
|
Args:
|
|
40
40
|
loc (int, float, list, numpy.ndarray, Tensor): The location of Gumbel distribution.
|
|
@@ -99,7 +99,7 @@ class Gumbel(TransformedDistribution):
|
|
|
99
99
|
|
|
100
100
|
# ops needed for the class
|
|
101
101
|
self.cast = P.Cast()
|
|
102
|
-
self.const = P.
|
|
102
|
+
self.const = P.ScalarToTensor()
|
|
103
103
|
self.exp = exp_generic
|
|
104
104
|
self.expm1 = P.Expm1()
|
|
105
105
|
self.fill = P.Fill()
|
|
@@ -175,7 +175,7 @@ class Gumbel(TransformedDistribution):
|
|
|
175
175
|
"""
|
|
176
176
|
scale = self.scale * \
|
|
177
177
|
self.fill(self.parameter_type, self.broadcast_shape, 1.0)
|
|
178
|
-
return scale * np.pi / self.sqrt(self.const(6.))
|
|
178
|
+
return scale * np.pi / self.sqrt(self.const(6., mstype.float32))
|
|
179
179
|
|
|
180
180
|
def _entropy(self):
|
|
181
181
|
r"""
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""HalfNormal Distribution"""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
from __future__ import division
|
|
18
|
+
import numpy as np
|
|
19
|
+
from mindspore import ops
|
|
20
|
+
from mindspore.ops import operations as P
|
|
21
|
+
from mindspore import _checkparam as Validator
|
|
22
|
+
from mindspore.common import dtype as mstype
|
|
23
|
+
from mindspore.nn.probability.distribution import Distribution
|
|
24
|
+
from mindspore.nn.probability.distribution._utils.utils import check_greater_zero
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class HalfNormal(Distribution):
|
|
28
|
+
r"""
|
|
29
|
+
HalfNormal distribution.
|
|
30
|
+
A HalfNormal distribution is a continuous distribution with the range :math:`[\mu, \inf)`
|
|
31
|
+
and the probability density function:
|
|
32
|
+
|
|
33
|
+
.. math::
|
|
34
|
+
f(x, \mu, \sigma) = 1 / \sigma\sqrt{2\pi} \exp(-(x - \mu)^2 / 2\sigma^2).
|
|
35
|
+
|
|
36
|
+
where :math:`\mu, \sigma` are the mean and the standard deviation of the half normal distribution respectively.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
|
|
40
|
+
If this arg is None, then the mean of the distribution will be passed in runtime. Default: None.
|
|
41
|
+
sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The standard deviation of the distribution.
|
|
42
|
+
If this arg is None, then the sd of the distribution will be passed in runtime. Default: None.
|
|
43
|
+
seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
|
|
44
|
+
dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
|
|
45
|
+
name (str, optional): The name of the distribution. Default: 'HalfNormal'.
|
|
46
|
+
|
|
47
|
+
Note:
|
|
48
|
+
- `sd` must be greater than zero.
|
|
49
|
+
- `dtype` must be a float type because HalfNormal distributions are continuous.
|
|
50
|
+
- If the arg `mean` or `sd` is passed in runtime, then it will be used as the parameter value.
|
|
51
|
+
Otherwise, the value passed in the constructor will be used.
|
|
52
|
+
|
|
53
|
+
Raises:
|
|
54
|
+
ValueError: When sd <= 0.
|
|
55
|
+
TypeError: When the input `dtype` is not a subclass of float.
|
|
56
|
+
|
|
57
|
+
Supported Platforms:
|
|
58
|
+
``CPU``
|
|
59
|
+
|
|
60
|
+
Examples:
|
|
61
|
+
>>> import mindspore
|
|
62
|
+
>>> import mindspore.nn as nn
|
|
63
|
+
>>> from mindspore.nn.probability.distribution import HalfNormal
|
|
64
|
+
>>> from mindspore import Tensor
|
|
65
|
+
>>> # To initialize a HalfNormal distribution of the mean 3.0 and the standard deviation 4.0.
|
|
66
|
+
>>> n1 = HalfNormal(3.0, 4.0, dtype=mindspore.float32)
|
|
67
|
+
>>> # A HalfNormal distribution can be initialized without arguments.
|
|
68
|
+
>>> # In this case, `mean` and `sd` must be passed in through arguments.
|
|
69
|
+
>>> hn = HalfNormal(dtype=mindspore.float32)
|
|
70
|
+
>>> # Here are some tensors used below for testing
|
|
71
|
+
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
|
72
|
+
>>> mean_a = Tensor([2.0], dtype=mindspore.float32)
|
|
73
|
+
>>> sd_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
|
|
74
|
+
>>> mean_b = Tensor([1.0], dtype=mindspore.float32)
|
|
75
|
+
>>> sd_b = Tensor([1.0, 1.5, 2.5], dtype=mindspore.float32)
|
|
76
|
+
>>> ans = n1.log_prob(value)
|
|
77
|
+
>>> print(ans.shape)
|
|
78
|
+
(3,)
|
|
79
|
+
>>> # Evaluate with respect to the distribution b.
|
|
80
|
+
>>> ans = n1.log_prob(value, mean_b, sd_b)
|
|
81
|
+
>>> print(ans.shape)
|
|
82
|
+
(3,)
|
|
83
|
+
>>> # `mean` and `sd` must be passed in during function calls
|
|
84
|
+
>>> ans = hn.log_prob(value, mean_a, sd_a)
|
|
85
|
+
>>> print(ans.shape)
|
|
86
|
+
(3,)
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
def __init__(self,
|
|
90
|
+
mean=None,
|
|
91
|
+
sd=None,
|
|
92
|
+
seed=None,
|
|
93
|
+
dtype=mstype.float32,
|
|
94
|
+
name="HalfNormal"):
|
|
95
|
+
"""
|
|
96
|
+
Constructor of HalfNormal.
|
|
97
|
+
"""
|
|
98
|
+
param = dict(locals())
|
|
99
|
+
param['param_dict'] = {'mean': mean, 'sd': sd}
|
|
100
|
+
valid_dtype = mstype.float_type
|
|
101
|
+
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
|
102
|
+
super(HalfNormal, self).__init__(seed, dtype, name, param)
|
|
103
|
+
|
|
104
|
+
self._mean_value = self._add_parameter(mean, 'mean')
|
|
105
|
+
self._sd_value = self._add_parameter(sd, 'sd')
|
|
106
|
+
if self._sd_value is not None:
|
|
107
|
+
check_greater_zero(self._sd_value, "Standard deviation")
|
|
108
|
+
|
|
109
|
+
self.exp = P.Exp()
|
|
110
|
+
self.cast = P.Cast()
|
|
111
|
+
self.const = ops.scalar_to_tensor(np.sqrt(2. / np.pi))
|
|
112
|
+
self.sq = P.Square()
|
|
113
|
+
self.type = dtype
|
|
114
|
+
|
|
115
|
+
def _prob(self, value, mean=None, sd=None):
|
|
116
|
+
r"""
|
|
117
|
+
Evaluate probability of the value of the HalfNormal distribution.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
value (Tensor): The value to be evaluated.
|
|
121
|
+
mean (Tensor, optional): The mean of the distribution. Default: self._mean_value.
|
|
122
|
+
sd (Tensor, optional): The standard deviation the distribution. Default: self._sd_value.
|
|
123
|
+
|
|
124
|
+
.. math::
|
|
125
|
+
P(x) = 1 / \sigma \sqrt{2\pi} \exp(-(x - \mu)^2 / 2\sigma^2)
|
|
126
|
+
"""
|
|
127
|
+
value = self._check_value(value, 'value')
|
|
128
|
+
value = self.cast(value, self.dtype)
|
|
129
|
+
mean, sd = self._check_param_type(mean, sd)
|
|
130
|
+
|
|
131
|
+
coeff = self.const / sd
|
|
132
|
+
pdf = coeff * self.exp(-0.5 * self.sq((value - mean) / sd))
|
|
133
|
+
return pdf * self.cast(value >= 0, self.type)
|