mindspore 2.0.0a0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater_zero, check_distribution_name

@@ -26,7 +26,7 @@ from ._utils.custom_ops import exp_generic, log_generic
 class Exponential(Distribution):
 r"""
 Exponential Distribution.
-An Exponential distributio is a continuous distribution with the range :math:`[0,
+An Exponential distributio is a continuous distribution with the range :math:`[0, \inf)`
 and the probability density function:

 .. math::
@@ -17,7 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 import mindspore.nn as nn
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater_zero, check_distribution_name

@@ -27,14 +27,14 @@ from ._utils.custom_ops import log_generic
 class Gamma(Distribution):
 r"""
 Gamma distribution.
-A Gamma distributio is a continuous distribution with the range :math:`
+A Gamma distributio is a continuous distribution with the range :math:`(0, \inf)`
 and the probability density function:

 .. math::
 f(x, \alpha, \beta) = \beta^\alpha / \Gamma(\alpha) x^{\alpha - 1} \exp(-\beta x).

 where :math:`G` is the Gamma function,
-and :math:`\alpha
+and :math:`\alpha` and :math:`\beta` are the concentration and the rate of the distribution respectively.

 Args:
 concentration (int, float, list, numpy.ndarray, Tensor): The concentration,
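The Gamma density documented in the hunk above can be restated in conventional display form (a standard restatement for reference, not text from the package):

\[
f(x;\alpha,\beta) \;=\; \frac{\beta^{\alpha}}{\Gamma(\alpha)}\, x^{\alpha-1} e^{-\beta x}, \qquad x > 0,
\]

where \(\alpha\) is the concentration and \(\beta\) the rate, as the updated docstring says.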
@@ -17,7 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_prob, check_distribution_name
@@ -15,7 +15,7 @@
 """Gumbel Distribution"""
 import numpy as np
 from mindspore.ops import operations as P
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 import mindspore.nn.probability.bijector as msb

@@ -28,13 +28,13 @@ from ._utils.custom_ops import exp_generic, log_generic
 class Gumbel(TransformedDistribution):
 r"""
 Gumbel distribution.
-A Gumbel distributio is a continuous distribution with the range
+A Gumbel distributio is a continuous distribution with the range of all real numbers
 and the probability density function:

 .. math::
 f(x, a, b) = 1 / b \exp(\exp(-(x - a) / b) - x),

-where a
+where :math:`a, b` are loc and scale parameter respectively.

 Args:
 loc (int, float, list, numpy.ndarray, Tensor): The location of Gumbel distribution.
@@ -16,8 +16,9 @@
 from __future__ import absolute_import
 from __future__ import division
 import numpy as np
+from mindspore import ops
 from mindspore.ops import operations as P
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.nn.probability.distribution import Distribution
 from mindspore.nn.probability.distribution._utils.utils import check_greater_zero

@@ -35,16 +36,19 @@ class HalfNormal(Distribution):
 where :math:`\mu, \sigma` are the mean and the standard deviation of the half normal distribution respectively.

 Args:
-mean (int, float, list, numpy.ndarray, Tensor): The mean of the distribution.
-
-
-
-
+mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
+If this arg is None, then the mean of the distribution will be passed in runtime. Default: None.
+sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The standard deviation of the distribution.
+If this arg is None, then the sd of the distribution will be passed in runtime. Default: None.
+seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
+dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
+name (str, optional): The name of the distribution. Default: 'HalfNormal'.

 Note:
 - `sd` must be greater than zero.
-- `dist_spec_args` are `mean` and `sd`.
 - `dtype` must be a float type because HalfNormal distributions are continuous.
+- If the arg `mean` or `sd` is passed in runtime, then it will be used as the parameter value.
+Otherwise, the value passed in the constructor will be used.

 Raises:
 ValueError: When sd <= 0.

@@ -104,18 +108,18 @@ class HalfNormal(Distribution):

 self.exp = P.Exp()
 self.cast = P.Cast()
-self.const = np.sqrt(2. / np.pi)
+self.const = ops.scalar_to_tensor(np.sqrt(2. / np.pi))
 self.sq = P.Square()
 self.type = dtype

 def _prob(self, value, mean=None, sd=None):
 r"""
-Evaluate probability.
+Evaluate probability of the value of the HalfNormal distribution.

 Args:
 value (Tensor): The value to be evaluated.
-mean (Tensor): The mean of the distribution. Default: self._mean_value.
-sd (Tensor): The standard deviation the distribution. Default: self._sd_value.
+mean (Tensor, optional): The mean of the distribution. Default: self._mean_value.
+sd (Tensor, optional): The standard deviation the distribution. Default: self._sd_value.

 .. math::
 P(x) = 1 / \sigma \sqrt{2\pi} \exp(-(x - \mu)^2 / 2\sigma^2)
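A minimal usage sketch of the two construction modes the updated docstring above describes; it assumes HalfNormal is exported from mindspore.nn.probability.distribution and follows the same calling pattern shown for Laplace later in this diff (values are illustrative):

import mindspore
from mindspore import Tensor
from mindspore.nn.probability.distribution import HalfNormal

# Parameters fixed in the constructor.
hn = HalfNormal(3.0, 4.0, dtype=mindspore.float32)
value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
p = hn.prob(value)

# Parameters left as None and supplied at call time instead.
hn_runtime = HalfNormal(dtype=mindspore.float32)
p_runtime = hn_runtime.prob(value, Tensor(3.0, mindspore.float32), Tensor(4.0, mindspore.float32))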
@@ -16,7 +16,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from mindspore.ops import operations as P
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.nn.probability.distribution import Distribution
 from mindspore.nn.probability.distribution._utils.utils import check_greater_zero

@@ -25,25 +25,28 @@ from mindspore.nn.probability.distribution._utils.utils import check_greater_zer
 class Laplace(Distribution):
 r"""
 Laplace distribution.
-A Laplace distribution is a continuous distribution with the range :math:`
+A Laplace distribution is a continuous distribution with the range :math:`(-\inf, \inf)`
 and the probability density function:

 .. math::
-f(x, \mu, b) = 1 / (2
+f(x, \mu, b) = 1 / (2 * b) * \exp(-abs(x - \mu) / b).

 where :math:`\mu, b` are the mean and the scale of the laplace distribution respectively.

 Args:
-mean (int, float, list, numpy.ndarray, Tensor): The mean of the distribution.
-
-
-
-
+mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
+If this arg is None, then the mean of the distribution will be passed in runtime. Default: None.
+sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The scale of the distribution.
+If this arg is None, then the scale of the distribution will be passed in runtime. Default: None.
+seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
+dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
+name (str, optional): The name of the distribution. Default: 'Laplace'.

 Note:
 - `sd` must be greater than zero.
-- `dist_spec_args` are `mean` and `sd`.
 - `dtype` must be a float type because Laplace distributions are continuous.
+- If the arg `mean` or `sd` is passed in runtime, then it will be used as the parameter value.
+Otherwise, the value passed in the constructor will be used.

 Raises:
 ValueError: When sd <= 0.

@@ -57,7 +60,7 @@ class Laplace(Distribution):
 >>> import mindspore.nn as nn
 >>> from mindspore.nn.probability.distribution import Laplace
 >>> from mindspore import Tensor
->>> # To initialize a Laplace distribution of the mean 3.0 and the
+>>> # To initialize a Laplace distribution of the mean 3.0 and the scale 4.0.
 >>> n1 = Laplace(3.0, 4.0, dtype=mindspore.float32)
 >>> # A Laplace distribution can be initialized without arguments.
 >>> # In this case, `mean` and `sd` must be passed in through arguments.

@@ -107,12 +110,12 @@ class Laplace(Distribution):

 def _log_prob(self, value, mean=None, sd=None):
 r"""
-Evaluate log probability.
+Evaluate log probability of the laplace distribution.

 Args:
 value (Tensor): The value to be evaluated.
-mean (Tensor): The mean of the distribution. Default: self._mean_value.
-sd (Tensor): The
+mean (Tensor, optional): The mean of the distribution. Default: self._mean_value.
+sd (Tensor, optional): The scale the distribution. Default: self._sd_value.

 .. math::
 L(x) = -1* \abs{\frac{x - \mu}{\sigma}} - \log(2. * \sigma))
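Restating the corrected Laplace density and the log-density used by _log_prob above in conventional notation (for reference only; \(\mu\) is the mean and \(b\) the scale):

\[
f(x;\mu,b) \;=\; \frac{1}{2b}\exp\!\left(-\frac{|x-\mu|}{b}\right), \qquad
\log f(x;\mu,b) \;=\; -\frac{|x-\mu|}{b} \;-\; \log(2b).
\]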
@@ -16,7 +16,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater_zero

@@ -32,7 +32,7 @@ class Logistic(Distribution):
 .. math::
 f(x, a, b) = 1 / b \exp(\exp(-(x - a) / b) - x).

-where a
+where :math:`a, b` are loc and scale parameter respectively.

 Args:
 loc (float, list, numpy.ndarray, Tensor): The location of the Logistic distribution. Default: None.
@@ -16,7 +16,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common import Tensor
 from .distribution import Distribution
@@ -17,7 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 import mindspore.nn as nn
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater_zero
@@ -18,7 +18,7 @@ from __future__ import division
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.nn.probability.distribution import Distribution
 from mindspore.nn.probability.distribution._utils.utils import check_greater_zero

@@ -27,28 +27,33 @@ from mindspore.nn.probability.distribution._utils.utils import check_greater_zer
 class StudentT(Distribution):
 r"""
 StudentT distribution.
-A StudentT distribution is a continuous distribution with the range :math:`
+A StudentT distribution is a continuous distribution with the range :math:`(-\inf, \inf)`
 and the probability density function:

 .. math::
 f(x, \nu, \mu, \sigma) = (1 + y^2 / \nu)^{(-0.5*(\nu + 1))} / Z

-where :math:`y = (x
-
+where :math:`y = (x - \mu)/ \sigma`,
+:math:`Z = abs(\sigma) * \sqrt{(\nu * \pi)} * \Gamma(0.5 * \nu) / \Gamma(0.5 * (\nu + 1))`,
+:math:`\nu, \mu, \sigma` are the degrees of freedom , mean and sd of the laplace distribution respectively.

 Args:
-df (int, float, list, numpy.ndarray, Tensor): The degrees of freedom.
-
-
-
-
-
+df (Union[int, float, list, numpy.ndarray, Tensor], optional): The degrees of freedom.
+If this arg is None, then the df of the distribution will be passed in runtime. Default: None.
+mean (Union[int, float, list, numpy.ndarray, Tensor], optional): The mean of the distribution.
+If this arg is None, then the df of the distribution will be passed in runtime. Default: None.
+sd (Union[int, float, list, numpy.ndarray, Tensor], optional): The standard deviation of the distribution.
+If this arg is None, then the sd of the distribution will be passed in runtime. Default: None.
+seed (int, optional): The seed used in sampling. The global seed is used if it is None. Default: None.
+dtype (mindspore.dtype, optional): The type of the event samples. Default: mstype.float32.
+name (str, optional): The name of the distribution. Default: 'StudentT'.

 Note:
 - `df` must be greater than zero.
 - `sd` must be greater than zero.
-- `dist_spec_args` are `mean` and `sd`.
 - `dtype` must be a float type because StudentT distributions are continuous.
+- If the arg `df`, `mean` or `sd` is passed in runtime, then it will be used as the parameter value.
+Otherwise, the value passed in the constructor will be used.

 Raises:
 ValueError: When df <= 0.

@@ -122,13 +127,13 @@ class StudentT(Distribution):

 def _log_prob(self, value, df=None, mean=None, sd=None):
 r"""
-Evaluate log probability.
+Evaluate log probability of the value of the StudentT distribution.

 Args:
 value (Tensor): The value to be evaluated.
-df (Tensor): The degrees of freedom of the distribution. Default: self._df_value.
-mean (Tensor): The mean of the distribution. Default: self._mean_value.
-sd (Tensor): The standard deviation the distribution. Default: self._sd_value.
+df (Tensor, optional): The degrees of freedom of the distribution. Default: self._df_value.
+mean (Tensor, optional): The mean of the distribution. Default: self._mean_value.
+sd (Tensor, optional): The standard deviation the distribution. Default: self._sd_value.

 .. math::
 L(x) = -0.5 * (\nu + 1.) * \log((x - \mu) / \sigma + 1.)) + \log(\sqrt(\pi * \mu * \sigma^2))
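The location-scale Student-t density documented above, with the normalizer \(Z\) expanded (standard form, for reference only):

\[
f(x;\nu,\mu,\sigma) \;=\; \frac{\Gamma\!\left(\tfrac{\nu+1}{2}\right)}{|\sigma|\sqrt{\nu\pi}\;\Gamma\!\left(\tfrac{\nu}{2}\right)}
\left(1+\frac{y^{2}}{\nu}\right)^{-\tfrac{\nu+1}{2}},
\qquad y=\frac{x-\mu}{\sigma}.
\]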
@@ -14,7 +14,7 @@
 # ============================================================================
 """Transformed Distribution"""
 import numpy as np
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn

@@ -28,9 +28,9 @@ class TransformedDistribution(Distribution):
 Transformed Distribution.
 This class contains a bijector and a distribution and transforms the original distribution
 to a new distribution through the operation defined by the bijector.
-If X is an random variable following the underying distribution,
-and g(x) is a function represented by the bijector,
-then Y = g(X) is a random variable following the transformed distribution.
+If :math:`X` is an random variable following the underying distribution,
+and :math:`g(x)` is a function represented by the bijector,
+then :math:`Y = g(X)` is a random variable following the transformed distribution.

 Args:
 bijector (Bijector): The transformation to perform.
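A small sketch of the Y = g(X) relationship described above, wrapping a Normal base distribution with an Exp bijector so that Y is log-normal; the concrete classes and values are illustrative, not taken from this diff:

import mindspore
from mindspore import Tensor
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd

# Base X ~ Normal(0, 1); bijector g = exp, so Y = exp(X).
log_normal = msd.TransformedDistribution(msb.Exp(), msd.Normal(0.0, 1.0, dtype=mindspore.float32))
x = Tensor([1.0, 2.0], dtype=mindspore.float32)
p = log_normal.prob(x)  # density of the transformed variable Y evaluated at x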
@@ -16,7 +16,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import check_greater, check_distribution_name

@@ -32,7 +32,7 @@ class Uniform(Distribution):
 .. math::
 f(x, a, b) = 1 / (b - a),

-where a
+where :math:`a, b` are the lower and upper bound respectively.

 Args:
 low (int, float, list, numpy.ndarray, Tensor): The lower bound of the distribution. Default: None.
@@ -19,7 +19,7 @@ from __future__ import absolute_import

 from mindspore.nn.cell import Cell
 from mindspore.ops.operations import _rl_inner_ops as rl_ops
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype


@@ -59,9 +59,9 @@ class TensorsQueue(Cell):
 """Initialize TensorsQueue"""
 super(TensorsQueue, self).__init__()
 Validator.check_subclass("dtype", dtype, mstype.number_type + (mstype.bool_,), self.cls_name)
-Validator.check_int(size, 0,
+Validator.check_int(size, 0, Validator.GE, "size", self.cls_name)
 elements_num = len(shapes)
-Validator.check_int(elements_num, 1,
+Validator.check_int(elements_num, 1, Validator.GE, "len(shapes)", self.cls_name)
 self.handle_ = rl_ops.TensorsQueueCreate(dtype, shapes, size, name)()
 self.tensors_q_put = rl_ops.TensorsQueuePut(dtype, shapes)
 self.tensors_q_get = rl_ops.TensorsQueueGet(dtype, shapes)
@@ -19,7 +19,7 @@ from __future__ import absolute_import

 from mindspore.nn.cell import Cell
 from mindspore.ops.operations import _tensor_array as ta
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype


@@ -62,7 +62,7 @@ class TensorArray(Cell):
 """Initialize TensorArray"""
 super(TensorArray, self).__init__()
 Validator.check_subclass("dtype", dtype, mstype.number_type + (mstype.bool_,), self.cls_name)
-Validator.check_int(size, 0,
+Validator.check_int(size, 0, Validator.GE, "size", self.cls_name)
 self.handle_ = ta.TensorArray(dtype, element_shape, dynamic_size, size, name)()
 self.tensor_array_write = ta.TensorArrayWrite()
 self.tensor_array_read = ta.TensorArrayRead(dtype, element_shape)
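A short sketch of the tightened argument checks in the two hunks above; it relies only on the call shape shown there (Validator is the mindspore._checkparam alias, Validator.GE the comparison constant), and the values are illustrative:

from mindspore import _checkparam as Validator

size = 5
shapes = [(2, 2), (4,)]
# Raise if size < 0 or if fewer than one shape is given, mirroring the
# TensorsQueue/TensorArray constructors above.
Validator.check_int(size, 0, Validator.GE, "size", "TensorsQueue")
Validator.check_int(len(shapes), 1, Validator.GE, "len(shapes)", "TensorsQueue")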
mindspore/nn/sparse/sparse.py
CHANGED
@@ -45,7 +45,7 @@ class SparseToDense(Cell):
 TypeError: If `sparse_tensor.shape` is not a tuple.

 Supported Platforms:
-``CPU``
+``GPU`` ``CPU``

 Examples:
 >>> import mindspore as ms

@@ -118,7 +118,7 @@ class SparseTensorDenseMatmul(Cell):
 and shape of `dense` don't meet the parameter description.

 Supported Platforms:
-``CPU``
+``GPU`` ``CPU``

 Examples:
 >>> import mindspore as ms
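Since ``GPU`` is added to the supported platforms above, a minimal SparseToDense call looks like the following; this is a sketch assuming the COOTensor input that the `sparse_tensor.shape` check in the docstring refers to, with illustrative indices and values:

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, COOTensor

indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
values = Tensor([1.0, 2.0], dtype=ms.float32)
shape = (3, 4)
sparse = COOTensor(indices, values, shape)
dense = nn.SparseToDense()(sparse)  # 3 x 4 dense Tensor with two non-zero entries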
@@ -23,11 +23,11 @@ from mindspore import log as logger
 from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
 _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
 from mindspore.context import ParallelMode
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore import ops, nn
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter, ParameterTuple
-from mindspore.ops.primitive import
+from mindspore.ops.primitive import _primexpr
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -110,7 +110,7 @@ class WithLossCell(Cell):
         super(WithLossCell, self).__init__(auto_prefix=False)
         self._backbone = backbone
         self._loss_fn = loss_fn
-        if backbone.jit_config_dict:
+        if isinstance(backbone, Cell) and backbone.jit_config_dict:
             self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, data, label):
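This hunk, and the similar ones below, make the wrapper cells copy `jit_config_dict` from the wrapped network only when it is actually a Cell. A brief sketch of the behavior this enables, assuming the public JitConfig and Cell.set_jit_config API available in this release:

    >>> import mindspore as ms
    >>> from mindspore import nn
    >>> backbone = nn.Dense(16, 4)
    >>> backbone.set_jit_config(ms.JitConfig(jit_level="O1"))   # attach a JIT config to the backbone
    >>> net_with_loss = nn.WithLossCell(backbone, nn.MSELoss())
    >>> # with this change the wrapper inherits backbone.jit_config_dict instead of ignoring it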
@@ -147,7 +147,7 @@ class WithGradCell(Cell):
             output value. Default: None.
 
     Inputs:
-        -
+        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
 
     Outputs:
         list, a list of Tensors with identical shapes as trainable weights.
@@ -182,6 +182,8 @@ class WithGradCell(Cell):
         else:
             self.network_with_loss = WithLossCell(self.network, self.loss_fn)
         self.network_with_loss.set_train()
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         weights = self.weights
@@ -216,8 +218,8 @@ class ForwardValueAndGrad(Cell):
         the input parameter.
 
     Inputs:
-        -
-        - **
+        - **\*inputs** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
+        - **sens** - A sensitivity (gradient with respect to output) as the input of backpropagation.
           If network has single output, the sens is a tensor.
           If network has multiple outputs, the sens is the tuple(tensor).
 
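A minimal sketch of how the `sens` input described above is supplied when `sens_param=True`; the network, shapes, and keyword defaults are illustrative and assumed from the documented ForwardValueAndGrad interface:

    >>> import numpy as np
    >>> import mindspore as ms
    >>> from mindspore import nn, Tensor
    >>> net = nn.Dense(3, 2)
    >>> weights = ms.ParameterTuple(net.trainable_params())
    >>> fvg = nn.ForwardValueAndGrad(net, weights=weights, get_by_list=True, sens_param=True)
    >>> x = Tensor(np.ones((4, 3), np.float32))
    >>> sens = Tensor(np.ones((4, 2), np.float32))   # gradient w.r.t. the network output
    >>> value, grads = fvg(x, sens)                  # sens is appended after the network inputs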
@@ -282,6 +284,8 @@ class ForwardValueAndGrad(Cell):
         self.get_by_list = get_by_list
         self.sens_param = sens_param
         self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         grad_inputs = inputs
@@ -309,7 +313,7 @@ class TrainOneStepCell(Cell):
         sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
 
     Inputs:
-        -
+        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
 
     Outputs:
         Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.
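The usual way these inputs are fed in practice, shown as a brief sketch (network, loss function, and data shapes are illustrative):

    >>> import numpy as np
    >>> import mindspore as ms
    >>> from mindspore import nn, Tensor
    >>> net = nn.Dense(3, 1)
    >>> loss_fn = nn.MSELoss()
    >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    >>> train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
    >>> data = Tensor(np.random.randn(4, 3).astype(np.float32))
    >>> label = Tensor(np.random.randn(4, 1).astype(np.float32))
    >>> loss = train_net(data, label)   # one forward/backward/update step, returns the scalar loss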
@@ -375,6 +379,8 @@ class TrainOneStepCell(Cell):
             create_group(server_group_name, group_list[current_index])
             group = server_group_name
         self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         loss = self.network(*inputs)
@@ -453,18 +459,19 @@ class _VirtualDatasetCell(Cell):
         super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
         self._backbone = backbone
         self._virtual_dataset = _VirtualDataset()
+        if isinstance(backbone, Cell) and backbone.jit_config_dict:
+            self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, *inputs):
         output = self._virtual_dataset(*inputs)
         return self._backbone(*output)
 
 
-@
+@_primexpr
 def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
     if input_shape[0] % micro_size != 0:
         raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
                          f"divided by micro size({micro_size})")
-    return True
 
 
 class _MicroBatch(Cell):
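The check rewritten above (now decorated with `_primexpr` and no longer returning a value) simply enforces that the leading batch dimension splits evenly into micro batches. The same condition in plain Python, as a quick illustration with made-up numbers:

    >>> input_shape, micro_size = (32, 3, 224, 224), 4
    >>> input_shape[0] % micro_size == 0   # 32 divides into 4 micro batches of 8
    True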
@@ -545,6 +552,8 @@ class MicroBatchInterleaved(Cell):
             interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
             interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
             self.interleave_inputs.append(interleave_data)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         output = 0.0
@@ -583,6 +592,8 @@ class PipelineCell(Cell):
             self.micro_inputs.append(micro_input)
             self.add = P.Add().add_prim_attr("pipeline_end", i)
             self.add_list.append(self.add)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         ret = None
@@ -611,6 +622,8 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
         self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
         self.hyper_map = ops.HyperMap()
         self.opt_shard = _get_enable_parallel_optimizer()
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         weights = self.weights
@@ -652,6 +665,8 @@ class VirtualDatasetCellTriple(Cell):
         super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
         logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
         self._backbone = backbone
+        if isinstance(backbone, Cell) and backbone.jit_config_dict:
+            self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, a, b, c):
         return self._backbone(a, b, c)
@@ -694,6 +709,8 @@ class WithEvalCell(Cell):
         self._network = network
         self._loss_fn = loss_fn
         self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, data, label):
         outputs = self._network(data)
@@ -717,7 +734,7 @@ class ParameterUpdate(Cell):
         - **x** (Tensor) - A tensor whose shape and type are the same with `param`.
 
     Outputs:
-        Tensor, the
+        Tensor, the updated value.
 
     Raises:
         KeyError: If parameter with the specified name does not exist.
@@ -315,11 +315,11 @@ class DistributedGradReducer(Cell):
|
|
|
315
315
|
|
|
316
316
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
317
317
|
Please see the `Ascend tutorial
|
|
318
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0
|
|
318
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
|
|
319
319
|
for more details.
|
|
320
320
|
|
|
321
321
|
For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
|
|
322
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0
|
|
322
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
|
|
323
323
|
|
|
324
324
|
This example should be run with multiple devices.
|
|
325
325
|
|