mindspore 2.0.0a0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/optim/lars.py
CHANGED

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.common import Tensor, Parameter, dtype as mstype
 from mindspore.common.api import jit
 from mindspore.nn.optim.optimizer import _grad_scale, Optimizer

@@ -83,7 +83,7 @@ class LARS(Optimizer):
         \end{array}

     :math:`w` represents the network parameters, :math:`g` represents `gradients`,
-    :math:`t` represents the current step, :math:`\
+    :math:`t` represents the current step, :math:`\lambda` represents `weight_decay` in `optimizer`,
     :math:`\gamma` represents `learning_rate` in `optimizer`, :math:`\eta` represents `coefficient`.

     Args:
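The `_checkparam` import rewrite above repeats in every optimizer file in this release: the private module is now imported as a namespace, and the old `Rel` interval constants are reached through it (e.g. `validator.INC_NEITHER` in the lazyadam and rprop hunks below). A minimal sketch of the new style; `_check_coefficient` is a hypothetical helper, but the validator calls mirror the ones visible in these diffs:

    # Hypothetical helper written in the 2.0.0rc1 validation style.
    from mindspore import _checkparam as validator

    def _check_coefficient(coefficient, prim_name="LARS"):
        validator.check_value_type("coefficient", coefficient, [float], prim_name)
        # INC_NEITHER: open interval (0.0, 1.0), formerly Rel.INC_NEITHER.
        validator.check_float_range(coefficient, 0.0, 1.0, validator.INC_NEITHER,
                                    "coefficient", prim_name)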
mindspore/nn/optim/lazyadam.py
CHANGED

@@ -23,8 +23,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
-from mindspore
-from mindspore._checkparam import Rel
+from mindspore import _checkparam as validator
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register
 from mindspore.nn.optim._dist_optimizer_registry import _register_dist_optimizer

@@ -86,6 +85,46 @@ def _run_opt_with_sparse_dist(opt, sparse_opt, push, pull, use_locking, use_nest
     return success


+@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
+                         "Tensor", "Tensor", "Tensor", "Tensor", "MapTensor", "MapTensor", "MapTensor", "MapTensor",
+                         "Bool", "Bool", "Function", "Bool", "Function", "Bool")
+def _run_map_tensor_opt_with_sparse_dist(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
+                                         beta2_power, beta1, beta2, eps, lr, gradient, params, m, v,
+                                         ps_parameter, cache_enable, distributed_opt, use_flag, distributed_sparse_opt,
+                                         use_sparse_flag):
+    """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
+    success = True
+    indices, values = gradient.get_data()
+    if use_sparse_flag:
+        # PS Mode.
+        success = F.depend(success, distributed_sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
+                                                           eps, values, indices))
+    else:
+        # PS Cache mode.
+        op_sqrt = P.Sqrt()
+
+        m_slice = m.get(indices)
+        v_slice = v.get(indices)
+
+        next_m = m_slice * beta1 + values * (1 - beta1)
+        next_v = v_slice * beta2 + values * values * (1 - beta2)
+
+        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
+
+        if use_nesterov:
+            m_temp = beta1 * next_m + values * (1 - beta1)
+            param_update = m_temp / (op_sqrt(next_v) + eps)
+        else:
+            param_update = next_m / (op_sqrt(next_v) + eps)
+
+        params_need_update = params.get(indices)
+        params.put(indices, params_need_update - lr_t * param_update)
+        m.put(indices, next_m)
+        v.put(indices, next_v)
+
+    return success
+
+
 @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                          "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool",
                          "Function", "Bool", "Function", "Bool")

@@ -212,8 +251,8 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
     validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
-    validator.check_float_range(beta1, 0.0, 1.0,
-    validator.check_float_range(beta2, 0.0, 1.0,
+    validator.check_float_range(beta1, 0.0, 1.0, validator.INC_NEITHER, "beta1", prim_name)
+    validator.check_float_range(beta2, 0.0, 1.0, validator.INC_NEITHER, "beta2", prim_name)
     validator.check_positive_float(eps, "eps", prim_name)
     validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
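The PS-cache branch of the new `_run_map_tensor_opt_with_sparse_dist` is a row-sparse ("lazy") Adam step: only the rows named by the gradient indices are read, updated, and written back. A NumPy sketch of the same arithmetic, for illustration only (a toy function, not MindSpore API):

    import numpy as np

    def lazy_adam_rows(params, m, v, indices, values, lr, beta1, beta2, eps, step):
        # Slice only the touched rows, as m.get(indices)/v.get(indices) do above.
        next_m = m[indices] * beta1 + values * (1 - beta1)
        next_v = v[indices] * beta2 + values * values * (1 - beta2)
        # Bias correction; beta1_power/beta2_power in the diff are beta**step.
        lr_t = lr * np.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
        params[indices] -= lr_t * next_m / (np.sqrt(next_v) + eps)
        m[indices], v[indices] = next_m, next_v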
mindspore/nn/optim/momentum.py
CHANGED

@@ -20,7 +20,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.common.api import jit
 import mindspore.common.dtype as mstype
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register
 from mindspore.nn.optim._dist_optimizer_registry import _register_dist_optimizer

@@ -69,19 +69,20 @@ class Momentum(Optimizer):
     learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.

     .. math::
-
+        v_{t+1} = v_{t} \ast u + grad

     If use_nesterov is True:

     .. math::
-
+        p_{t+1} = p_{t} - (grad \ast lr + v_{t+1} \ast u \ast lr)

     If use_nesterov is False:

     .. math::
-
+        p_{t+1} = p_{t} - lr \ast v_{t+1}

-    Here: where grad
+    Here: where :math:`grad`, :math:`lr`, :math:`p`, :math:`v` and :math:`u` denote the gradients,
+    learning_rate, params, moments, and momentum respectively.

     Note:
         If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
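A quick numeric check of the restored docstring formulas (toy scalar values chosen here for illustration, not from the source):

    u, lr = 0.9, 0.1            # momentum and learning_rate
    p, v, grad = 1.0, 0.5, 2.0  # parameter, moment, gradient

    v_next = v * u + grad                            # v_{t+1} = 2.45
    p_plain = p - lr * v_next                        # non-Nesterov: 0.755
    p_nesterov = p - (grad * lr + v_next * u * lr)   # Nesterov: 0.5795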
mindspore/nn/optim/optimizer.py
CHANGED

@@ -29,7 +29,7 @@ from mindspore.common.initializer import initializer
 from mindspore.common import Tensor
 from mindspore.common.sparse_tensor import RowTensorInner
 import mindspore.common.dtype as mstype
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore import log as logger
 from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
 from mindspore.parallel._ps_context import _is_ps_mode

@@ -181,6 +181,7 @@ class Optimizer(Cell):
             self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)

         self._init_opt_attrs(learning_rate, parameters, weight_decay)
+        self.add_flags(skip_auto_parallel_compile=True)

     def _init_opt_attrs(self, learning_rate, parameters, weight_decay):
         """initialize optimizer attributions"""

@@ -718,6 +719,7 @@ class Optimizer(Cell):

         Examples:
             >>> from mindspore import nn
+            >>> # net = LeNet5()
             >>> net = Net()
             >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
             >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))

mindspore/nn/optim/proximal_ada_grad.py
CHANGED

@@ -19,7 +19,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common import Tensor
 import mindspore.common.dtype as mstype
 from mindspore.common.api import jit
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register

@@ -158,7 +158,7 @@ class ProximalAdagrad(Optimizer):
         ValueError: If `accum`, `l1`, `l2` or `weight_decay` is less than 0.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU``

     Examples:
         >>> import mindspore as ms
mindspore/nn/optim/rmsprop.py
CHANGED

@@ -16,7 +16,7 @@
 from __future__ import absolute_import

 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.common.api import jit
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register
mindspore/nn/optim/rprop.py
CHANGED

@@ -19,8 +19,7 @@ from mindspore import ops
 from mindspore.ops import operations as P
 import mindspore.common.dtype as mstype
 from mindspore.common.api import jit
-from mindspore
-from mindspore._checkparam import Rel
+from mindspore import _checkparam as validator
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register

@@ -37,12 +36,12 @@ class Rprop(Optimizer):
     .. math::
         \begin{gather*}
             &\hspace{-10mm} \textbf{if} \: g_{t-1} g_t > 0 \\
-            &\hspace{25mm} \Delta_t \leftarrow \mathrm{min}(\Delta_{t-1} \eta_{+}, \Delta_{max})
-            &\hspace{0mm} \textbf{else if} \: g_{t-1} g_t < 0
-            &\hspace{25mm} \Delta_t \leftarrow \mathrm{max}(\Delta_{t-1} \eta_{-}, \Delta_{min})
-            &\hspace{-25mm} \textbf{else} \:
-            &\hspace{-5mm} \Delta_t \leftarrow \Delta_{t-1}
-            &\hspace{15mm} w_{t} \leftarrow w_{t-1}- \Delta_{t} \mathrm{sign}(g_t)
+            &\hspace{25mm} \Delta_t \leftarrow \mathrm{min}(\Delta_{t-1} \eta_{+}, \Delta_{max}) \\
+            &\hspace{0mm} \textbf{else if} \: g_{t-1} g_t < 0 \\
+            &\hspace{25mm} \Delta_t \leftarrow \mathrm{max}(\Delta_{t-1} \eta_{-}, \Delta_{min}) \\
+            &\hspace{-25mm} \textbf{else} \: \\
+            &\hspace{-5mm} \Delta_t \leftarrow \Delta_{t-1} \\
+            &\hspace{15mm} w_{t} \leftarrow w_{t-1}- \Delta_{t} \mathrm{sign}(g_t) \\
         \end{gather*}

     :math:`\Delta_{min/max}` represents the min/max step size, :math:`\eta_{+/-}` represents the factors of

@@ -175,7 +174,7 @@ class Rprop(Optimizer):
             raise ValueError("For Rprop, maximal step size should not be less than minimal step size, "
                              "but got {} > {}.".format(step_sizes[0], step_sizes[1]))

-        validator.check_float_range(etas[0], 0.0, 1.0,
+        validator.check_float_range(etas[0], 0.0, 1.0, validator.INC_NEITHER, "etaminus", self.cls_name)
         validator.check_value_type("etaplus", etas[1], [float], self.cls_name)
         if etas[1] <= 1.0:
             raise ValueError("For Rprop, etaplus must be greater than 1.0, but got etaplus {}.".format(etas[1]))
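The corrected formula block is the standard Rprop update without weight backtracking: grow the step where the gradient kept its sign, shrink it where the sign flipped. A NumPy sketch of one step (illustrative only, not the MindSpore kernel; default constants are the usual Rprop choices):

    import numpy as np

    def rprop_step(w, g, g_prev, delta, eta_minus=0.5, eta_plus=1.2,
                   step_min=1e-6, step_max=50.0):
        sign = np.sign(g * g_prev)
        # Same sign -> min(delta * eta_plus, step_max); flipped -> max(delta * eta_minus, step_min).
        delta = np.where(sign > 0, np.minimum(delta * eta_plus, step_max),
                np.where(sign < 0, np.maximum(delta * eta_minus, step_min), delta))
        return w - delta * np.sign(g), delta   # w_t = w_{t-1} - Delta_t * sign(g_t)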
mindspore/nn/optim/sgd.py
CHANGED

@@ -20,15 +20,15 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.common.api import jit
 import mindspore.common.dtype as mstype
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register

 _sgd_opt = C.MultitypeFuncGraph("sgd_opt")


-@_sgd_opt.register("
-def _tensor_run_opt_ext(
+@_sgd_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Function")
+def _tensor_run_opt_ext(momentum, learning_rate, gradient, weight, accum, stat, opt):
     """Apply sgd optimizer to the weight parameter using Tensor."""
     success = True
     success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))

@@ -76,7 +76,9 @@ class SGD(Optimizer):
            - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.

-           - weight_decay:
+           - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
+             will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
+             decay must be float, dynamic weight decay is currently not supported.

            - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
              will be used. If not, the `grad_centralization` is False by default. This configuration only works on the

@@ -164,7 +166,7 @@ class SGD(Optimizer):

         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("For 'SGD', the argument 'momentum' must be at least 0.0, "
-                             "but got {}".format(momentum))
+                             "but got {}.".format(momentum))

         if isinstance(dampening, int):
             dampening = float(dampening)

@@ -177,9 +179,6 @@ class SGD(Optimizer):
                              "but got 'dampening' {}".format(dampening))
         self.dampening = dampening

-        if isinstance(weight_decay, int):
-            weight_decay = float(weight_decay)
-
         validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)

         if nesterov and (momentum <= 0.0 or dampening != 0.0):

@@ -187,7 +186,14 @@ class SGD(Optimizer):
                              "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
         self.nesterov = nesterov

-        self.
+        if self.dynamic_weight_decay:
+            raise TypeError("For 'SGD', dynamic weight decay is currently not supported, the argument 'weight_decay' "
+                            "or 'weight_decay' set in grouped 'params' must be float or int type.")
+
+        if hasattr(self, "group_weight_decay") and self.group_weight_decay:
+            self.opt = tuple(P.SGD(dampening, wd, nesterov) for wd in self.group_weight_decay)
+        else:
+            self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))

         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.accum = self._parameters.clone(prefix="accum", init='zeros')

@@ -203,9 +209,9 @@ class SGD(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
-            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.
-                                             lr, gradients, params, accum, stat)
+            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum),
+                                             lr, gradients, params, accum, stat, self.opt)
         else:
-            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.
-                                             gradients, params, accum, stat)
+            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum, lr),
+                                             gradients, params, accum, stat, self.opt)
         return success
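The `__init__` hunk replaces the single shared `P.SGD` with one primitive per weight-decay group, and `construct()` threads that tuple (`self.opt`) through the hyper map so each parameter is updated by the op carrying its group's decay. A toy sketch of the pairing; the closures are stand-ins, not the real `P.SGD` kernel, and folding decay into the gradient is assumed for illustration:

    # Stand-in for P.SGD(dampening, weight_decay, nesterov): one op per group.
    def make_sgd_op(weight_decay, lr=0.01):
        def op(weight, grad):
            return weight - lr * (grad + weight_decay * weight)  # decay folded per group (assumption)
        return op

    group_weight_decay = (0.0, 1e-4, 1e-4)   # one entry per parameter
    opt = tuple(make_sgd_op(wd) for wd in group_weight_decay)
    weights, grads = [1.0, 2.0, 3.0], [0.1, 0.1, 0.1]
    updated = [op(w, g) for op, w, g in zip(opt, weights, grads)]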
mindspore/nn/optim/thor.py
CHANGED

@@ -25,7 +25,7 @@ import mindspore.ops as ops
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
 import mindspore.log as logger
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
 from mindspore import context

@@ -254,11 +254,6 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
    r"""
    Updates gradients by second-order algorithm--THOR.

-    Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:
-
-    `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
-    <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_
-
    The updating formulas are as follows,

    .. math::

@@ -314,9 +309,9 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0

        enable_clip_grad (bool): Whether to clip the gradients. Default: False

-        frequency(int): The update interval of A/G and
-                        A/G and
-
+        frequency(int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
+                        (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
+                        and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: 100.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

@@ -341,8 +336,8 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
    .. note::
        Before running the following example, you need to customize the network Net and
        dataset preparation function create_dataset. Refer to
-        `Building a Network <https://www.mindspore.cn/tutorials/en/r2.0
-        and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0
+        `Building a Network <https://www.mindspore.cn/tutorials/en/r2.0/beginner/model.html>`_
+        and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0/beginner/dataset.html>`_ .

    >>> import mindspore as ms
    >>> from mindspore.nn import thor

@@ -973,15 +968,15 @@ class ThorAscend(Optimizer):
         matrix_g_combine_shape = self.shape(matrix_g_inv)
         if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
             matrix_a_inv = self.reshape(matrix_a_inv,
-                                        (matrix_a_inv_shape[0]
-                                         matrix_a_inv_shape[0]
+                                        (matrix_a_inv_shape[0] // 16, 16,
+                                         matrix_a_inv_shape[0] // 16, 16))
             matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
             matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)

             matrix_g_inv_shape = self.shape(matrix_g_inv)
             matrix_g_inv = self.reshape(matrix_g_inv,
-                                        (matrix_g_inv_shape[0]
-                                         matrix_g_inv_shape[0]
+                                        (matrix_g_inv_shape[0] // 16, 16,
+                                         matrix_g_inv_shape[0] // 16, 16))
             matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

             matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
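The reshape/transpose pair in `ThorAscend` tiles an n-by-n inverse matrix into 16x16 blocks (my reading: the Ascend fractal block layout; the naming is inferred from the code, not documented here). A NumPy illustration of what `(n//16, 16, n//16, 16)` followed by `transpose(2, 0, 1, 3)` produces:

    import numpy as np

    n = 32                                   # must be a multiple of 16
    mat = np.arange(n * n, dtype=np.float32).reshape(n, n)
    blocked = mat.reshape(n // 16, 16, n // 16, 16).transpose(2, 0, 1, 3)
    # blocked[j, i] is the 16x16 tile at block-row i, block-column j of mat.
    assert np.array_equal(blocked[1, 0], mat[0:16, 16:32])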
@@ -19,7 +19,7 @@ from mindspore.ops import operations as P
|
|
|
19
19
|
from mindspore.ops.operations import _inner_ops as inner
|
|
20
20
|
from mindspore.common import dtype as mstype
|
|
21
21
|
from mindspore.common.tensor import Tensor
|
|
22
|
-
from mindspore
|
|
22
|
+
from mindspore import _checkparam as validator
|
|
23
23
|
from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
|
|
24
24
|
from ..distribution import Distribution
|
|
25
25
|
from ..distribution import TransformedDistribution
|
|
@@ -28,9 +28,9 @@ from ..distribution import TransformedDistribution
|
|
|
28
28
|
class Bijector(Cell):
|
|
29
29
|
"""
|
|
30
30
|
Bijecotr class. A bijector perform a mapping from one distribution to the other via some function.
|
|
31
|
-
If X is a random variable following the original distribution,
|
|
32
|
-
and g(x) is the mapping function,
|
|
33
|
-
then Y = g(X) is the random variable following the transformed distribution.
|
|
31
|
+
If :math:`X` is a random variable following the original distribution,
|
|
32
|
+
and :math:`g(x)` is the mapping function,
|
|
33
|
+
then :math:`Y = g(X)` is the random variable following the transformed distribution.
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
36
|
is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: False.
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Invert Bijector"""
|
|
16
|
-
from mindspore
|
|
16
|
+
from mindspore import _checkparam as validator
|
|
17
17
|
from .bijector import Bijector
|
|
18
18
|
|
|
19
19
|
|
|
@@ -41,7 +41,7 @@ class Softplus(Bijector):
|
|
|
41
41
|
TypeError: When the dtype of the sharpness is not float.
|
|
42
42
|
|
|
43
43
|
Supported Platforms:
|
|
44
|
-
``Ascend`` ``GPU``
|
|
44
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
45
45
|
|
|
46
46
|
Examples:
|
|
47
47
|
>>> import mindspore
|
|
@@ -51,7 +51,7 @@ class Softplus(Bijector):
|
|
|
51
51
|
>>>
|
|
52
52
|
>>> # To initialize a Softplus bijector of sharpness 2.0.
|
|
53
53
|
>>> softplus = msb.Softplus(2.0)
|
|
54
|
-
>>> # To use a
|
|
54
|
+
>>> # To use a Softplus bijector in a network.
|
|
55
55
|
>>> value = Tensor([1, 2, 3], dtype=mindspore.float32)
|
|
56
56
|
>>> ans1 = softplus.forward(value)
|
|
57
57
|
>>> print(ans1.shape)
|
|
@@ -14,7 +14,7 @@
 # ============================================================================
 """dense_variational"""
 from mindspore.ops import operations as P
-from mindspore
+from mindspore import _checkparam as Validator
 from ...cell import Cell
 from ...layer.activation import get_activation
 from ..distribution.normal import Normal
@@ -29,7 +29,7 @@ class NormalPrior(Cell):
     To initialize a normal distribution of mean 0 and standard deviation 0.1.

     Args:
-        dtype (
+        dtype (mindspore.dtype): The argument is used to define the data type of the output tensor.
             Default: mindspore.float32.
         mean (int, float): Mean of normal distribution. Default: 0.
         std (int, float): Standard deviation of normal distribution. Default: 0.1.
@@ -55,7 +55,7 @@ class NormalPosterior(Cell):
     Args:
         name (str): Name prepended to trainable parameter.
         shape (list, tuple): Shape of the mean and standard deviation.
-        dtype (
+        dtype (mindspore.dtype): The argument is used to define the data type of the output tensor.
             Default: mindspore.float32.
         loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
         loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
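A short usage sketch for the two cells documented above, assuming the public mindspore.nn.probability.bnn_layers import path and the constructor arguments listed in the Args blocks:

    import mindspore
    from mindspore.nn.probability.bnn_layers import NormalPrior, NormalPosterior

    # Prior: N(0, 0.1) over weights, emitted in float32 as documented above.
    prior = NormalPrior(dtype=mindspore.float32, mean=0, std=0.1)

    # Posterior: trainable mean/std parameters of the given shape.
    posterior = NormalPosterior(name="weight", shape=[64, 32],
                                dtype=mindspore.float32, loc_mean=0, loc_std=0.1)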
@@ -15,12 +15,12 @@
 """Utility functions to help distribution class."""
 import numpy as np
 from mindspore import context
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
-from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
+from mindspore.ops.primitive import constexpr, _primexpr, PrimitiveWithInfer, prim_attr_register
 import mindspore.ops as ops
 import mindspore.nn as nn

@@ -230,48 +230,42 @@ def probs_to_logits(probs, is_binary=False):
     return P.Log()(ps_clamped)


-@constexpr
+@constexpr(check=False)
 def raise_none_error(name):
     raise TypeError(f"the type {name} must be subclass of Tensor."
                     f" It can not be None since it is not specified during initialization.")


-@constexpr
-def raise_probs_logits_error():
-    raise TypeError(
-        "Either 'probs' or 'logits' must be specified, but not both.")
-
-
-@constexpr
+@_primexpr
 def raise_broadcast_error(shape_a, shape_b):
     raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")


-@constexpr
+@constexpr(check=False)
 def raise_not_impl_error(name):
     raise ValueError(
         f"{name} function must be implemented for non-linear transformation")


-@constexpr
+@constexpr(check=False)
 def raise_not_implemented_util(func_name, obj, *args, **kwargs):
     raise NotImplementedError(
         f"{func_name} is not implemented for {obj} distribution.")


-@constexpr
+@constexpr(check=False)
 def raise_type_error(name, cur_type, required_type):
     raise TypeError(
         f"For {name} , the type must be or be subclass of {required_type}, but got {cur_type}")


-@constexpr
+@constexpr(check=False)
 def raise_not_defined(func_name, obj, *args, **kwargs):
     raise ValueError(
         f"{func_name} is undefined for {obj} distribution.")


-@constexpr
+@constexpr(check=False)
 def check_distribution_name(name, expected_name):
     if name is None:
         raise ValueError(
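The hunk is mechanical: every compile-time error raiser moves from bare @constexpr to @constexpr(check=False) (argument checking is pointless for a function that only raises), the unused raise_probs_logits_error is deleted, and raise_broadcast_error becomes @_primexpr. A minimal sketch of the new decorator convention; the function itself is a made-up example, not from the diff:

    from mindspore.ops.primitive import constexpr

    @constexpr(check=False)
    def raise_shape_error(name, shape):
        # evaluated at graph-compile time; it only ever raises
        raise ValueError(f"{name} got an unsupported shape {shape}.")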
@@ -16,16 +16,16 @@
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from .distribution import Distribution
 from ._utils.utils import check_prob, check_distribution_name, clamp_probs
 from ._utils.custom_ops import exp_generic, log_generic


 class Bernoulli(Distribution):
-    """
+    r"""
     Bernoulli Distribution.
-    A Bernoulli Distribution is a discrete distribution with the range {0, 1}
+    A Bernoulli Distribution is a discrete distribution with the range :math:`\{0, 1\}`
     and the probability mass function as :math:`P(X = 0) = p, P(X = 1) = 1-p`.

     Args:
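A worked evaluation of the mass function exactly as the docstring above states it (a NumPy stand-in, not the operator implementation):

    import numpy as np

    # P(X = 0) = p, P(X = 1) = 1 - p, per the docstring above
    def bernoulli_pmf(x, p):
        return np.where(x == 0, p, 1.0 - p)

    print(bernoulli_pmf(np.array([0, 1]), 0.3))  # [0.3 0.7]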
@@ -17,7 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 import mindspore.nn as nn
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater_zero, check_distribution_name
@@ -19,20 +19,20 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops.functional import stop_gradient
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore
+from mindspore import _checkparam as Validator
 import mindspore.ops as ops
 import mindspore.nn as nn
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
-    check_distribution_name
+    check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic, broadcast_to


 class Categorical(Distribution):
-    """
+    r"""
     Categorical distribution.
-    A Categorical Distribution is a discrete distribution with the range {1, 2, ..., k}
+    A Categorical Distribution is a discrete distribution with the range :math:`\{1, 2, ..., k\}`
     and the probability mass function as :math:`P(X = i) = p_i, i = 1, ..., k`.

     Args:
@@ -238,7 +238,7 @@ class Categorical(Distribution):
         """
         probs = self._check_param_type(probs)
         logits = self.log(probs)
-        return self.squeeze(
+        return self.squeeze(P.Neg()(self.reduce_sum(logits * probs, -1)))

     def _kl_loss(self, dist, probs_b, probs=None):
         """
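The new line computes the entropy in one fused expression: negate the reduced sum of logits * probs over the last axis, then squeeze. Its NumPy equivalent, as a check:

    import numpy as np

    probs = np.array([[0.25, 0.25, 0.5]])
    logits = np.log(probs)

    # Neg(ReduceSum(logits * probs, -1)) followed by squeeze
    entropy = np.squeeze(-(logits * probs).sum(axis=-1))
    print(entropy)  # ~1.0397 nats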
@@ -405,8 +405,6 @@ class Categorical(Distribution):
         Returns:
             Tensor, shape is shape(probs)[:-1] + sample_shape
         """
-        if self.device_target == 'Ascend':
-            raise_not_implemented_util('On d backend, sample', self.name)
         shape = self.checktuple(shape, 'shape')
         probs = self._check_param_type(probs)
         num_classes = self.shape(probs)[-1]
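With the Ascend guard gone, sample takes the same path on every backend. A NumPy sketch of the documented return shape, shape(probs)[:-1] + sample_shape (an illustrative draw, not the operator's algorithm):

    import numpy as np

    rng = np.random.default_rng(0)
    probs = np.array([0.2, 0.3, 0.5])   # shape(probs) = (3,)
    sample_shape = (2, 4)

    # draw class indices with the given probabilities
    samples = rng.choice(len(probs), size=sample_shape, p=probs)
    print(samples.shape)  # (2, 4) == shape(probs)[:-1] + sample_shape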
@@ -16,7 +16,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater_zero, check_distribution_name, raise_not_defined
@@ -26,13 +26,13 @@ from ._utils.custom_ops import exp_generic, log_generic, log1p_generic
 class Cauchy(Distribution):
     r"""
     Cauchy distribution.
-    A Cauchy distribution is a continuous distribution with the range
+    A Cauchy distribution is a continuous distribution with the range of all real numbers
     and the probability density function:

     .. math::
         f(x, a, b) = 1 / \pi b(1 + ((x - a)/b)^2),

-    where a
+    where :math:`a, b` are loc and scale parameters respectively.

     Args:
         loc (int, float, list, numpy.ndarray, Tensor): The location of the Cauchy distribution. Default: None.
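A worked evaluation of the density above, with a as loc and b as scale (a NumPy stand-in):

    import numpy as np

    def cauchy_pdf(x, a, b):
        # f(x, a, b) = 1 / (pi * b * (1 + ((x - a) / b) ** 2))
        return 1.0 / (np.pi * b * (1.0 + ((x - a) / b) ** 2))

    print(cauchy_pdf(0.0, 0.0, 1.0))  # 1/pi ~= 0.3183, the peak at the location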
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.nn.cell import Cell
 from mindspore.ops.primitive import constexpr
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore
+from mindspore import _checkparam as validator
 from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
     raise_not_implemented_util
 from ._utils.utils import CheckTuple, CheckTensor
@@ -102,7 +102,7 @@ class Distribution(Cell):
         self.device_target = context.get_context('device_target')
         self.checktuple = CheckTuple()

-        @constexpr
+        @constexpr(check=False)
         def _check_tensor(x, name):
             CheckTensor()(x, name)
             return x