mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,262 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""Transform DNN to BNN."""
|
|
16
|
-
import mindspore.nn as nn
|
|
17
|
-
from ...wrap.cell_wrapper import TrainOneStepCell
|
|
18
|
-
from ....nn import optim
|
|
19
|
-
from ....nn import layer
|
|
20
|
-
from ...probability import bnn_layers
|
|
21
|
-
from ..bnn_layers.bnn_cell_wrapper import WithBNNLossCell
|
|
22
|
-
from ..bnn_layers.conv_variational import ConvReparam
|
|
23
|
-
from ..bnn_layers.dense_variational import DenseReparam
|
|
24
|
-
|
|
25
|
-
__all__ = ['TransformToBNN']
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class TransformToBNN:
|
|
29
|
-
r"""
|
|
30
|
-
Transform Deep Neural Network (DNN) model to Bayesian Neural Network (BNN) model.
|
|
31
|
-
|
|
32
|
-
Args:
|
|
33
|
-
trainable_dnn (Cell): A trainable DNN model (backbone) wrapped by TrainOneStepCell.
|
|
34
|
-
dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function. Default: 1.
|
|
35
|
-
bnn_factor (int, float): The coefficient of KL loss, which is KL divergence of Bayesian layer. Default: 1.
|
|
36
|
-
|
|
37
|
-
Supported Platforms:
|
|
38
|
-
``Ascend`` ``GPU``
|
|
39
|
-
|
|
40
|
-
Examples:
|
|
41
|
-
>>> from mindspore.nn.probability import bnn_layers
|
|
42
|
-
>>>
|
|
43
|
-
>>> class Net(nn.Cell):
|
|
44
|
-
... def __init__(self):
|
|
45
|
-
... super(Net, self).__init__()
|
|
46
|
-
... self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
|
47
|
-
... self.bn = nn.BatchNorm2d(64)
|
|
48
|
-
... self.relu = nn.ReLU()
|
|
49
|
-
... self.flatten = nn.Flatten()
|
|
50
|
-
... self.fc = nn.Dense(64*224*224, 12) # padding=0
|
|
51
|
-
...
|
|
52
|
-
... def construct(self, x):
|
|
53
|
-
... x = self.conv(x)
|
|
54
|
-
... x = self.bn(x)
|
|
55
|
-
... x = self.relu(x)
|
|
56
|
-
... x = self.flatten(x)
|
|
57
|
-
... out = self.fc(x)
|
|
58
|
-
... return out
|
|
59
|
-
>>>
|
|
60
|
-
>>> net = Net()
|
|
61
|
-
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
62
|
-
>>> optim = nn.AdamWeightDecay(params=net.trainable_params(), learning_rate=0.0001)
|
|
63
|
-
>>> net_with_loss = nn.WithLossCell(net, criterion)
|
|
64
|
-
>>> train_network = nn.TrainOneStepCell(net_with_loss, optim)
|
|
65
|
-
>>> bnn_transformer = TransformToBNN(train_network, 60000, 0.0001)
|
|
66
|
-
"""
|
|
67
|
-
|
|
68
|
-
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
|
|
69
|
-
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
|
70
|
-
raise TypeError('The type of `dnn_factor` must be `int` or `float`')
|
|
71
|
-
if dnn_factor < 0:
|
|
72
|
-
raise ValueError('The value of `dnn_factor` must >= 0')
|
|
73
|
-
|
|
74
|
-
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
|
75
|
-
raise TypeError('The type of `bnn_factor` must be `int` or `float`')
|
|
76
|
-
if bnn_factor < 0:
|
|
77
|
-
raise ValueError('The value of `bnn_factor` must >= 0')
|
|
78
|
-
|
|
79
|
-
net_with_loss = trainable_dnn.network
|
|
80
|
-
self.optimizer = trainable_dnn.optimizer
|
|
81
|
-
self.backbone = net_with_loss.backbone_network
|
|
82
|
-
self.loss_fn = getattr(net_with_loss, "_loss_fn")
|
|
83
|
-
self.dnn_factor = dnn_factor
|
|
84
|
-
self.bnn_factor = bnn_factor
|
|
85
|
-
|
|
86
|
-
def transform_to_bnn_model(self,
|
|
87
|
-
get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias,
|
|
88
|
-
"out_channels": dp.out_channels,
|
|
89
|
-
"activation": dp.activation},
|
|
90
|
-
get_conv_args=lambda dp: {"in_channels": dp.in_channels,
|
|
91
|
-
"out_channels": dp.out_channels,
|
|
92
|
-
"pad_mode": dp.pad_mode, "kernel_size": dp.kernel_size,
|
|
93
|
-
"stride": dp.stride, "has_bias": dp.has_bias,
|
|
94
|
-
"padding": dp.padding, "dilation": dp.dilation,
|
|
95
|
-
"group": dp.group},
|
|
96
|
-
add_dense_args=None,
|
|
97
|
-
add_conv_args=None):
|
|
98
|
-
r"""
|
|
99
|
-
Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.
|
|
100
|
-
|
|
101
|
-
Args:
|
|
102
|
-
get_dense_args: The arguments gotten from the DNN full connection layer. Default: lambda dp:
|
|
103
|
-
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
|
|
104
|
-
get_conv_args: The arguments gotten from the DNN convolutional layer. Default: lambda dp:
|
|
105
|
-
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
|
|
106
|
-
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
|
|
107
|
-
add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
|
|
108
|
-
`add_dense_args` must not duplicate arguments in `get_dense_args`. Default: None.
|
|
109
|
-
add_conv_args (dict): The new arguments added to BNN convolutional layer. Note that the arguments in
|
|
110
|
-
`add_conv_args` must not duplicate arguments in `get_conv_args`. Default: None.
|
|
111
|
-
|
|
112
|
-
Returns:
|
|
113
|
-
Cell, a trainable BNN model wrapped by TrainOneStepCell.
|
|
114
|
-
|
|
115
|
-
Supported Platforms:
|
|
116
|
-
``Ascend`` ``GPU``
|
|
117
|
-
|
|
118
|
-
Examples:
|
|
119
|
-
>>> net = Net()
|
|
120
|
-
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
121
|
-
>>> optim = nn.AdamWeightDecay(params=net.trainable_params(), learning_rate=0.0001)
|
|
122
|
-
>>> net_with_loss = nn.WithLossCell(net, criterion)
|
|
123
|
-
>>> train_network = nn.TrainOneStepCell(net_with_loss, optim)
|
|
124
|
-
>>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
|
|
125
|
-
>>> train_bnn_network = bnn_transformer.transform_to_bnn_model()
|
|
126
|
-
"""
|
|
127
|
-
if not add_dense_args:
|
|
128
|
-
add_dense_args = {}
|
|
129
|
-
if not add_conv_args:
|
|
130
|
-
add_conv_args = {}
|
|
131
|
-
|
|
132
|
-
self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args)
|
|
133
|
-
|
|
134
|
-
# rename layers of BNN model to prevent duplication of names
|
|
135
|
-
for value, param in self.backbone.parameters_and_names():
|
|
136
|
-
param.name = value
|
|
137
|
-
|
|
138
|
-
bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor)
|
|
139
|
-
bnn_optimizer = self._create_optimizer_with_bnn_params()
|
|
140
|
-
train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
|
|
141
|
-
return train_bnn_network
|
|
142
|
-
|
|
143
|
-
def transform_to_bnn_layer(self, dnn_layer_type, bnn_layer_type, get_args=None, add_args=None):
|
|
144
|
-
r"""
|
|
145
|
-
Transform a specific type of layers in DNN model to corresponding BNN layer.
|
|
146
|
-
|
|
147
|
-
Args:
|
|
148
|
-
dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are
|
|
149
|
-
nn.Dense and nn.Conv2d.
|
|
150
|
-
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
|
|
151
|
-
DenseReparam and ConvReparam.
|
|
152
|
-
get_args: The arguments gotten from the DNN layer. Default: None.
|
|
153
|
-
add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` must not
|
|
154
|
-
duplicate arguments in `get_args`. Default: None.
|
|
155
|
-
|
|
156
|
-
Returns:
|
|
157
|
-
Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the
|
|
158
|
-
corresponding bayesian layer.
|
|
159
|
-
|
|
160
|
-
Supported Platforms:
|
|
161
|
-
``Ascend`` ``GPU``
|
|
162
|
-
|
|
163
|
-
Examples:
|
|
164
|
-
>>> import mindspore.nn as nn
|
|
165
|
-
>>> from mindspore.nn.probability import bnn_layers
|
|
166
|
-
>>> net = LeNet()
|
|
167
|
-
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
168
|
-
>>> optim = nn.AdamWeightDecay(params=net.trainable_params(), learning_rate=0.0001)
|
|
169
|
-
>>> net_with_loss = nn.WithLossCell(net, criterion)
|
|
170
|
-
>>> train_network = nn.TrainOneStepCell(net_with_loss, optim)
|
|
171
|
-
>>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
|
|
172
|
-
>>> train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, bnn_layers.DenseReparam)
|
|
173
|
-
"""
|
|
174
|
-
if dnn_layer_type.__name__ not in ["Dense", "Conv2d"]:
|
|
175
|
-
raise ValueError(' \'dnn_layer\'' + str(dnn_layer_type) +
|
|
176
|
-
', should be one of values in \'nn.Dense\', \'nn.Conv2d\'.')
|
|
177
|
-
|
|
178
|
-
if bnn_layer_type.__name__ not in ["DenseReparam", "ConvReparam"]:
|
|
179
|
-
raise ValueError(' \'bnn_layer\'' + str(bnn_layer_type) +
|
|
180
|
-
', should be one of values in \'DenseReparam\', \'ConvReparam\'.')
|
|
181
|
-
|
|
182
|
-
dnn_layer_type = getattr(layer, dnn_layer_type.__name__)
|
|
183
|
-
bnn_layer_type = getattr(bnn_layers, bnn_layer_type.__name__)
|
|
184
|
-
|
|
185
|
-
if not get_args:
|
|
186
|
-
if dnn_layer_type.__name__ == "Dense":
|
|
187
|
-
get_args = self._get_dense_args
|
|
188
|
-
else:
|
|
189
|
-
get_args = self._get_conv_args
|
|
190
|
-
|
|
191
|
-
if not add_args:
|
|
192
|
-
add_args = {}
|
|
193
|
-
|
|
194
|
-
self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, add_args)
|
|
195
|
-
for value, param in self.backbone.parameters_and_names():
|
|
196
|
-
param.name = value
|
|
197
|
-
|
|
198
|
-
bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor)
|
|
199
|
-
bnn_optimizer = self._create_optimizer_with_bnn_params()
|
|
200
|
-
|
|
201
|
-
train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
|
|
202
|
-
return train_bnn_network
|
|
203
|
-
|
|
204
|
-
def _get_dense_args(self, dense_layer):
|
|
205
|
-
"""Get arguments from dense layer."""
|
|
206
|
-
dense_args = {"in_channels": dense_layer.in_channels, "has_bias": dense_layer.has_bias,
|
|
207
|
-
"out_channels": dense_layer.out_channels, "activation": dense_layer.activation}
|
|
208
|
-
return dense_args
|
|
209
|
-
|
|
210
|
-
def _get_conv_args(self, conv_layer):
|
|
211
|
-
"""Get arguments from conv2d layer."""
|
|
212
|
-
conv_args = {"in_channels": conv_layer.in_channels, "out_channels": conv_layer.out_channels,
|
|
213
|
-
"pad_mode": conv_layer.pad_mode, "kernel_size": conv_layer.kernel_size,
|
|
214
|
-
"stride": conv_layer.stride, "has_bias": conv_layer.has_bias,
|
|
215
|
-
"padding": conv_layer.padding, "dilation": conv_layer.dilation,
|
|
216
|
-
"group": conv_layer.group}
|
|
217
|
-
return conv_args
|
|
218
|
-
|
|
219
|
-
def _create_optimizer_with_bnn_params(self):
|
|
220
|
-
"""Create new optimizer that contains bnn trainable parameters."""
|
|
221
|
-
name = self.optimizer.__class__.__name__
|
|
222
|
-
modules = optim.__all__
|
|
223
|
-
|
|
224
|
-
if name not in modules:
|
|
225
|
-
raise TypeError('The optimizer can be {}, but got {}'.format(str(modules), name))
|
|
226
|
-
|
|
227
|
-
optimizer = getattr(optim, name)
|
|
228
|
-
|
|
229
|
-
args = {'params': self.backbone.trainable_params()}
|
|
230
|
-
params = optimizer.__init__.__code__.co_varnames
|
|
231
|
-
_params = self.optimizer.__dict__['_params']
|
|
232
|
-
for param in params:
|
|
233
|
-
if param in _params:
|
|
234
|
-
args[param] = self.optimizer.__getattr__(param).data.asnumpy().tolist()
|
|
235
|
-
|
|
236
|
-
new_optimizer = optimizer(**args)
|
|
237
|
-
return new_optimizer
|
|
238
|
-
|
|
239
|
-
def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args):
|
|
240
|
-
"""Replace both dense layer and conv2d layer in DNN model to bayesian layers."""
|
|
241
|
-
for name, cell in backbone.name_cells().items():
|
|
242
|
-
if isinstance(cell, nn.Dense):
|
|
243
|
-
dense_args = get_dense_args(cell)
|
|
244
|
-
new_layer = DenseReparam(**dense_args, **add_dense_args)
|
|
245
|
-
setattr(backbone, name, new_layer)
|
|
246
|
-
elif isinstance(cell, nn.Conv2d):
|
|
247
|
-
conv_args = get_conv_args(cell)
|
|
248
|
-
new_layer = ConvReparam(**conv_args, **add_conv_args)
|
|
249
|
-
setattr(backbone, name, new_layer)
|
|
250
|
-
else:
|
|
251
|
-
self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args,
|
|
252
|
-
add_conv_args)
|
|
253
|
-
|
|
254
|
-
def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args):
|
|
255
|
-
"""Convert a specific type of layers in DNN model to corresponding bayesian layers."""
|
|
256
|
-
for name, cell in backbone.name_cells().items():
|
|
257
|
-
if isinstance(cell, dnn_layer):
|
|
258
|
-
args = get_args(cell)
|
|
259
|
-
new_layer = bnn_layer(**args, **add_args)
|
|
260
|
-
setattr(backbone, name, new_layer)
|
|
261
|
-
else:
|
|
262
|
-
self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args)
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
""" Zhusuan package: a probalistic programming library """
|
|
16
|
-
|
|
17
|
-
from .framework import BayesianNet
|
|
18
|
-
from .variational import ELBO
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
""" Core functionality for Zhusuan """
|
|
17
|
-
|
|
18
|
-
from .bn import BayesianNet
|
|
@@ -1,95 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
""" Bayesian Network """
|
|
16
|
-
|
|
17
|
-
import mindspore.nn as nn
|
|
18
|
-
|
|
19
|
-
import mindspore.nn.probability.distribution as msd
|
|
20
|
-
from mindspore.common import dtype as mstype
|
|
21
|
-
from mindspore.ops import operations as P
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class BayesianNet(nn.Cell):
|
|
25
|
-
"""
|
|
26
|
-
We currently support 3 types of variables: x = observation, z = latent, y = condition.
|
|
27
|
-
A Bayeisian Network models a generative process for certain variables: p(x,z|y) or p(z|x,y) or p(x|z,y)
|
|
28
|
-
"""
|
|
29
|
-
|
|
30
|
-
def __init__(self):
|
|
31
|
-
super().__init__()
|
|
32
|
-
self.normal_dist = msd.Normal(dtype=mstype.float32)
|
|
33
|
-
self.bernoulli_dist = msd.Bernoulli(dtype=mstype.float32)
|
|
34
|
-
|
|
35
|
-
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
|
36
|
-
|
|
37
|
-
def normal(self,
|
|
38
|
-
name,
|
|
39
|
-
observation=None,
|
|
40
|
-
mean=None,
|
|
41
|
-
std=None,
|
|
42
|
-
seed=0,
|
|
43
|
-
dtype=mstype.float32,
|
|
44
|
-
shape=(),
|
|
45
|
-
reparameterize=True):
|
|
46
|
-
""" Normal distribution wrapper """
|
|
47
|
-
|
|
48
|
-
if not isinstance(name, str):
|
|
49
|
-
raise TypeError("The type of `name` must be string")
|
|
50
|
-
|
|
51
|
-
if observation is None:
|
|
52
|
-
if reparameterize:
|
|
53
|
-
epsilon = self.normal_dist('sample', shape, self.zeros(
|
|
54
|
-
mean.shape), self.ones(std.shape))
|
|
55
|
-
sample = mean + std * epsilon
|
|
56
|
-
else:
|
|
57
|
-
sample = self.normal_dist('sample', shape, mean, std)
|
|
58
|
-
else:
|
|
59
|
-
sample = observation
|
|
60
|
-
|
|
61
|
-
log_prob = self.reduce_sum(self.normal_dist(
|
|
62
|
-
'log_prob', sample, mean, std), 1)
|
|
63
|
-
return sample, log_prob
|
|
64
|
-
|
|
65
|
-
def bernoulli(self,
|
|
66
|
-
name,
|
|
67
|
-
observation=None,
|
|
68
|
-
probs=None,
|
|
69
|
-
seed=0,
|
|
70
|
-
dtype=mstype.float32,
|
|
71
|
-
shape=()):
|
|
72
|
-
""" Bernoulli distribution wrapper """
|
|
73
|
-
|
|
74
|
-
if not isinstance(name, str):
|
|
75
|
-
raise TypeError("The type of `name` must be string")
|
|
76
|
-
|
|
77
|
-
if observation is None:
|
|
78
|
-
sample = self.bernoulli_dist('sample', shape, probs)
|
|
79
|
-
else:
|
|
80
|
-
sample = observation
|
|
81
|
-
|
|
82
|
-
log_prob = self.reduce_sum(
|
|
83
|
-
self.bernoulli_dist('log_prob', sample, probs), 1)
|
|
84
|
-
return sample, log_prob
|
|
85
|
-
|
|
86
|
-
def construct(self, *inputs, **kwargs):
|
|
87
|
-
"""
|
|
88
|
-
We currently fix the parameters of the construct function.
|
|
89
|
-
Args:
|
|
90
|
-
the inputs must consist of 3 variables in order.
|
|
91
|
-
x: data sample, observation
|
|
92
|
-
z: latent variable
|
|
93
|
-
y: conditional information
|
|
94
|
-
"""
|
|
95
|
-
raise NotImplementedError
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
""" Variational inference related codes """
|
|
17
|
-
|
|
18
|
-
from .elbo import ELBO
|
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
""" ELBO """
|
|
16
|
-
|
|
17
|
-
import mindspore.nn as nn
|
|
18
|
-
|
|
19
|
-
from mindspore.ops import operations as P
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class ELBO(nn.Cell):
|
|
23
|
-
""" ELBO class """
|
|
24
|
-
def __init__(self, generator, variational):
|
|
25
|
-
super().__init__()
|
|
26
|
-
self.generator = generator
|
|
27
|
-
self.variational = variational
|
|
28
|
-
self.reshape_op = P.Reshape()
|
|
29
|
-
self.reduce_mean = P.ReduceMean(keep_dims=False)
|
|
30
|
-
self.square = P.Square()
|
|
31
|
-
|
|
32
|
-
def construct(self, *inputs, **kwargs):
|
|
33
|
-
if len(inputs) >= 2:
|
|
34
|
-
x, y = inputs[0], inputs[1]
|
|
35
|
-
elif len(inputs) >= 1:
|
|
36
|
-
x = inputs[0]
|
|
37
|
-
y = None
|
|
38
|
-
else:
|
|
39
|
-
x = None
|
|
40
|
-
y = None
|
|
41
|
-
|
|
42
|
-
z, log_prob_z = self.variational(x, None, y)
|
|
43
|
-
_, log_prob_x_, _, log_prob_z_ = self.generator(x, z, y)
|
|
44
|
-
|
|
45
|
-
elbo = self.reduce_mean(log_prob_x_) + self.reduce_mean(log_prob_z_) - self.reduce_mean(log_prob_z)
|
|
46
|
-
return -elbo
|
|
@@ -1,42 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""ParallelConcat op"""
|
|
17
|
-
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
|
18
|
-
|
|
19
|
-
parallel_concat_op_info = AiCPURegOp("ParallelConcat") \
|
|
20
|
-
.fusion_type("OPAQUE") \
|
|
21
|
-
.input(0, "x", "dynamic") \
|
|
22
|
-
.output(0, "y", "required") \
|
|
23
|
-
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
|
24
|
-
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
|
25
|
-
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
|
26
|
-
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
|
27
|
-
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
|
28
|
-
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
|
29
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
|
30
|
-
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
|
31
|
-
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
|
32
|
-
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
|
33
|
-
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
|
34
|
-
.dtype_format(DataType.C64_Default, DataType.C64_Default) \
|
|
35
|
-
.dtype_format(DataType.C128_Default, DataType.C128_Default) \
|
|
36
|
-
.get_op_info()
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@op_info_register(parallel_concat_op_info)
|
|
40
|
-
def _parallel_concat_aicpu():
|
|
41
|
-
"""ParallelConcat AiCPU register"""
|
|
42
|
-
return
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""GatherV2 op"""
|
|
17
|
-
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
18
|
-
|
|
19
|
-
gather_v2_op_info = TBERegOp("Gather") \
|
|
20
|
-
.fusion_type("OPAQUE") \
|
|
21
|
-
.async_flag(False) \
|
|
22
|
-
.binfile_name("gather_v2_d.so") \
|
|
23
|
-
.compute_cost(10) \
|
|
24
|
-
.kernel_name("gather_v2_d") \
|
|
25
|
-
.partial_flag(True) \
|
|
26
|
-
.attr("axis", "required", "int", "all") \
|
|
27
|
-
.input(0, "x", False, "required", "all") \
|
|
28
|
-
.input(1, "indices", False, "required", "all") \
|
|
29
|
-
.output(0, "y", False, "required", "all") \
|
|
30
|
-
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
|
31
|
-
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
|
32
|
-
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
|
33
|
-
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
|
34
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
|
35
|
-
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
|
|
36
|
-
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
|
|
37
|
-
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
|
|
38
|
-
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
|
39
|
-
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
|
|
40
|
-
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
|
41
|
-
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
|
42
|
-
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
|
|
43
|
-
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
|
|
44
|
-
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
|
45
|
-
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
|
|
46
|
-
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
|
|
47
|
-
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
|
|
48
|
-
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
|
49
|
-
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
|
|
50
|
-
.get_op_info()
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
@op_info_register(gather_v2_op_info)
|
|
54
|
-
def _gather_v2_tbe():
|
|
55
|
-
"""GatherV2 TBE register"""
|
|
56
|
-
return
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
0.1.1 MindSpore*2.0.0:�
|
|
3
|
-
}'get_bprop_assign_add.1231:[CNode]1232:1'get_bprop_assign_add.1231:[CNode]1232:1"REF::bprop.1233:Default/bprop.1233-op958get_bprop_assign_add.1231*
|
|
4
|
-
get_bprop_assign_add.1231:self*
|
|
5
|
-
get_bprop_assign_add.1231:x*
|
|
6
|
-
get_bprop_assign_add.1231:y*
|
|
7
|
-
get_bprop_assign_add.1231:out*
|
|
8
|
-
get_bprop_assign_add.1231:dout2)
|
|
9
|
-
'get_bprop_assign_add.1231:[CNode]1232:1:@864154f4834e62d84d34aab9399558528d5e734f6725d5daf7fbc1907cb32a1aJ/grad_math_ops.pyB�
|
|
10
|
-
�
|
|
11
|
-
get_bprop_assign_add.1231:xbprop.1233:[CNode]1234:2bprop.1233:[CNode]1234:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op959
|
|
12
|
-
�
|
|
13
|
-
get_bprop_assign_add.1231:ybprop.1233:[CNode]1235:3bprop.1233:[CNode]1235:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op960
|
|
14
|
-
�
|
|
15
|
-
bprop.1233:[CNode]1234:2
|
|
16
|
-
bprop.1233:[CNode]1235:3bprop.1233:[CNode]1236:4bprop.1233:[CNode]1236:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op961
|
|
17
|
-
bprop.12332
|
|
18
|
-
bprop.1233:[CNode]1236:4Pb&
|
|
19
|
-
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
0.1.1 MindSpore*2.0.0:�
|
|
3
|
-
]get_bprop_cast.6:[CNode]7:1get_bprop_cast.6:[CNode]7:1"REF::bprop.8:Default/bprop.8-op3get_bprop_cast.6*
|
|
4
|
-
get_bprop_cast.6:self*
|
|
5
|
-
get_bprop_cast.6:x*
|
|
6
|
-
get_bprop_cast.6:t*
|
|
7
|
-
get_bprop_cast.6:out*
|
|
8
|
-
get_bprop_cast.6:dout2
|
|
9
|
-
get_bprop_cast.6:[CNode]7:1:@2a049f3579950913c6ea42bb677f44470016652aa549a6dee2350ea48d50f039J/grad_array_ops.pyB�
|
|
10
|
-
�
|
|
11
|
-
get_bprop_cast.6:dout
|
|
12
|
-
get_bprop_cast.6:xbprop.8:dx:2bprop.8:dx:2"REF::MetaFuncGraph::dout_cast:Default/S-Prim-dout_cast-op4
|
|
13
|
-
�
|
|
14
|
-
get_bprop_cast.6:tbprop.8:[CNode]9:3bprop.8:[CNode]9:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
|
|
15
|
-
�
|
|
16
|
-
bprop.8:dx:2
|
|
17
|
-
bprop.8:[CNode]9:3bprop.8:[CNode]10:4bprop.8:[CNode]10:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op6bprop.82
|
|
18
|
-
bprop.8:[CNode]10:4Pb&
|
|
19
|
-
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
0.1.1 MindSpore*2.0.0:�
|
|
3
|
-
}'get_bprop_logical_or.1225:[CNode]1226:1'get_bprop_logical_or.1225:[CNode]1226:1"REF::bprop.1227:Default/bprop.1227-op954get_bprop_logical_or.1225*
|
|
4
|
-
get_bprop_logical_or.1225:self*
|
|
5
|
-
get_bprop_logical_or.1225:x*
|
|
6
|
-
get_bprop_logical_or.1225:y*
|
|
7
|
-
get_bprop_logical_or.1225:out*
|
|
8
|
-
get_bprop_logical_or.1225:dout2)
|
|
9
|
-
'get_bprop_logical_or.1225:[CNode]1226:1:@906051cca7d6d4b88a09a10b80bb5f0541066115667786dd7364cba0508be483J/grad_math_ops.pyB�
|
|
10
|
-
�
|
|
11
|
-
get_bprop_logical_or.1225:xbprop.1227:[CNode]1228:2bprop.1227:[CNode]1228:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op955
|
|
12
|
-
�
|
|
13
|
-
get_bprop_logical_or.1225:ybprop.1227:[CNode]1229:3bprop.1227:[CNode]1229:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op956
|
|
14
|
-
�
|
|
15
|
-
bprop.1227:[CNode]1228:2
|
|
16
|
-
bprop.1227:[CNode]1229:3bprop.1227:[CNode]1230:4bprop.1227:[CNode]1230:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op957
|
|
17
|
-
bprop.12272
|
|
18
|
-
bprop.1227:[CNode]1230:4Pb&
|
|
19
|
-
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
Binary file
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
0.1.1 MindSpore*2.0.0:�
|
|
3
|
-
q!get_bprop_relu.1273:[CNode]1274:1!get_bprop_relu.1273:[CNode]1274:1"REF::bprop.1275:Default/bprop.1275-op987
|
|
4
|
-
� get_bprop_relu.1273:input_grad:2 get_bprop_relu.1273:input_grad:2";REF::ClassType::mindspore.ops.operations._grad_ops.ReluGrad:ADefault/class 'mindspore.ops.operations._grad_ops.ReluGrad'-op988get_bprop_relu.1273*
|
|
5
|
-
get_bprop_relu.1273:self*
|
|
6
|
-
get_bprop_relu.1273:x*
|
|
7
|
-
get_bprop_relu.1273:out*
|
|
8
|
-
get_bprop_relu.1273:dout2#
|
|
9
|
-
!get_bprop_relu.1273:[CNode]1274:1:@92fe953ca276bba8b43b979895a017d6283e879df08b1a28bbae49b041d910b5J/grad_nn_ops.pyB�
|
|
10
|
-
�
|
|
11
|
-
get_bprop_relu.1273:dout
|
|
12
|
-
get_bprop_relu.1273:outbprop.1275:dx:3bprop.1275:dx:3"%REF::get_bprop_relu.1273:input_grad:2:989
|
|
13
|
-
~
|
|
14
|
-
bprop.1275:dx:3bprop.1275:[CNode]1276:4bprop.1275:[CNode]1276:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op990
|
|
15
|
-
bprop.12752
|
|
16
|
-
bprop.1275:[CNode]1276:4Pb&
|
|
17
|
-
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
Binary file
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
0.1.1 MindSpore*2.0.0:�
|
|
3
|
-
�
|
|
4
|
-
bprop_update_state.1331:u_monad%bprop_update_state.1331:[CNode]1332:1%bprop_update_state.1331:[CNode]1332:1".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:0Default/S-Prim-hyper_map[zeros_like_leaf]-op1032
|
|
5
|
-
�
|
|
6
|
-
bprop_update_state.1331:x%bprop_update_state.1331:[CNode]1333:2%bprop_update_state.1331:[CNode]1333:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:0Default/S-Prim-hyper_map[zeros_like_leaf]-op1033
|
|
7
|
-
�
|
|
8
|
-
%bprop_update_state.1331:[CNode]1332:1
|
|
9
|
-
%bprop_update_state.1331:[CNode]1333:2%bprop_update_state.1331:[CNode]1334:3%bprop_update_state.1331:[CNode]1334:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op1034bprop_update_state.1331*!
|
|
10
|
-
bprop_update_state.1331:u_monad*
|
|
11
|
-
bprop_update_state.1331:x*
|
|
12
|
-
bprop_update_state.1331:out*
|
|
13
|
-
bprop_update_state.1331:dout2'
|
|
14
|
-
%bprop_update_state.1331:[CNode]1334:3:@adde661b2e2f38680f80f0269960db717f31a852326a1f3bda24e010ea8bb153J/grad_implementations.pyPb&
|
|
15
|
-
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|