mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
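The brace entries above are renames, and several of them move public Python modules: mindspore/{nn → train}/metrics/* relocates the metrics API (with a new mindspore/nn/metrics.py, +53 -0, left behind), mindspore/{nn/transformer → parallel/_transformer} makes the old transformer package private while a new mindspore/nn/layer/transformer.py appears, and the mindspore/compression and mindspore/nn/probability {dpn, infer, toolbox, transforms, zhusuan} trees are deleted outright. A minimal sketch of the import change implied by the metrics move; it assumes the leftover nn/metrics.py shim keeps the old path importable in 2.0.0rc1:

    # MindSpore 1.10.0
    from mindspore.nn.metrics import Accuracy

    # MindSpore 2.0.0rc1: metrics now live under mindspore.train
    # (mindspore/nn/metrics.py appears to remain as a compatibility shim)
    from mindspore.train.metrics import Accuracy

    metric = Accuracy('classification')
    metric.clear()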
|
@@ -14,23 +14,31 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
16
|
"""Defines gradient related operators with functional form."""
|
|
17
|
-
|
|
18
17
|
from __future__ import absolute_import
|
|
19
18
|
from functools import partial
|
|
20
|
-
|
|
19
|
+
import numpy as np
|
|
20
|
+
from mindspore.common import jit, mutable
|
|
21
21
|
from mindspore.common import Tensor
|
|
22
22
|
from mindspore.common import dtype as mstype
|
|
23
|
-
from mindspore.nn.
|
|
24
|
-
from mindspore.nn.grad.cell_grad import _VjpInner
|
|
23
|
+
from mindspore.nn.cell import Cell
|
|
25
24
|
from mindspore.nn.grad.cell_grad import _LinearizeInner
|
|
26
25
|
from mindspore.ops.primitive import constexpr
|
|
27
|
-
from mindspore.ops.function import ones, expand_dims
|
|
28
|
-
from mindspore.ops.composite import _Grad, _TaylorOperation
|
|
26
|
+
from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose
|
|
27
|
+
from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
|
|
29
28
|
from mindspore.ops import operations as P
|
|
29
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
30
30
|
|
|
31
31
|
cast = P.Cast()
|
|
32
32
|
dtype = P.DType()
|
|
33
33
|
zeros = P.Zeros()
|
|
34
|
+
oneslike = P.OnesLike()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@constexpr
|
|
38
|
+
def _check_has_aux_type(inputs):
|
|
39
|
+
if not isinstance(inputs, bool):
|
|
40
|
+
raise TypeError("The 'has_aux' must be bool type.")
|
|
41
|
+
return True
|
|
34
42
|
|
|
35
43
|
|
|
36
44
|
@constexpr
|
|
@@ -38,15 +46,27 @@ def _raise_type_error():
|
|
|
38
46
|
raise TypeError("The inputs type must be a Tensor, tuple or list of Tensors.")
|
|
39
47
|
|
|
40
48
|
|
|
49
|
+
@constexpr
|
|
50
|
+
def _check_duplicate_grad_position(grad_position):
|
|
51
|
+
"""Check if `grad_position` has duplicate positions when `grad_position` has more than one numbers."""
|
|
52
|
+
if len(set(grad_position)) != len(grad_position):
|
|
53
|
+
raise ValueError("There are duplicate positions in `grad_position`, please check it")
|
|
54
|
+
|
|
55
|
+
|
|
41
56
|
@constexpr
|
|
42
57
|
def _convert_grad_position_type(grad_position):
|
|
43
58
|
"""Check and convert the type and size of grad position index."""
|
|
44
59
|
if isinstance(grad_position, tuple):
|
|
45
|
-
|
|
60
|
+
_check_duplicate_grad_position(grad_position)
|
|
61
|
+
_grad_position = list(grad_position)
|
|
62
|
+
for i, gp in enumerate(_grad_position):
|
|
63
|
+
if isinstance(gp, bool):
|
|
64
|
+
_grad_position[i] = int(gp)
|
|
46
65
|
if not isinstance(gp, int):
|
|
47
66
|
raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int.")
|
|
48
67
|
if gp < 0:
|
|
49
68
|
raise ValueError("The element in grad_position must be >= 0.")
|
|
69
|
+
grad_position = tuple(_grad_position)
|
|
50
70
|
elif isinstance(grad_position, int):
|
|
51
71
|
if grad_position < 0:
|
|
52
72
|
raise ValueError("grad_position must be >= 0.")
@@ -57,11 +77,22 @@ def _convert_grad_position_type(grad_position):
 
 
 @constexpr
-def
-
+def _check_grad_position(grad_position, args_num):
+    """Check and convert grad position index."""
+    grad_position = _convert_grad_position_type(grad_position)
+    for gp in grad_position:
+        if gp < 0 or gp >= args_num:
+            raise ValueError("The element in grad_position must belong to [0, args_num).")
+    return grad_position
+
+
+@constexpr
+def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False, return_ids=False):
+    return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value,
+                 return_ids=return_ids)
 
 
-def grad(fn, grad_position=0, weights=None, has_aux=False):
+def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
     """
     A wrapper function to generate the gradient function for the input function.
 
@@ -84,11 +115,19 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
         has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other
            outputs will be returned directly. This means `fn` must return more than one output in this case.
            Default: False.
+        return_ids (bool): Whether to pair each gradient with the input position it was differentiated with respect
+            to, or with the name of the parameter of the training network whose gradient was calculated. If True,
+            every gradient in the output is replaced by a tuple of that position index or parameter name and the
+            gradient itself. Default: False.
 
     Returns:
         Function, the gradient function to calculate gradient for the input function or cell.
         For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
        like `(gradient, out2)` and `out2` does not contribute to the differentiation, otherwise `gradient`.
+        When `return_ids` is set to True, the output keeps the same format as with `return_ids` set to False, except
+        that every gradient in it is replaced by a tuple of the position index or parameter name and the gradient.
 
     Raises:
         ValueError: If both `grad_position` and `weights` are None.
@@ -102,7 +141,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, ops
-        >>> from mindspore
+        >>> from mindspore import grad
        >>>
        >>> # Cell object to be differentiated
        >>> class Net(nn.Cell):
@@ -131,7 +170,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
        >>> print(aux)
        (Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
        >>>
-        >>> # For given network to be differentiated with both inputs and weights, there are
+        >>> # For given network to be differentiated with both inputs and weights, there are 4 cases.
        >>> net = nn.Dense(10, 1)
        >>> loss_fn = nn.MSELoss()
        >>> def forward(inputs, labels):
@@ -163,17 +202,36 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
        >>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
        >>> print(len(weights), len(params_gradient))
        2 2
+        >>> # Case 4: return the gradient with ids.
+        >>> import numpy as np
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import grad
+        >>>
+        >>> # Cell object to be differentiated
+        >>> class Net(nn.Cell):
+        ...     def construct(self, x, y, z):
+        ...         return x * y * z
+        >>> x = Tensor([1, 2], mindspore.float32)
+        >>> y = Tensor([-2, 3], mindspore.float32)
+        >>> z = Tensor([0, 3], mindspore.float32)
+        >>> net = Net()
+        >>> output = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
+        >>> print(output)
+        ((1, Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00, 6.00000000e+00])),
+        (2, Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00, 6.00000000e+00])))
     """
     if grad_position is None and weights is None:
         raise ValueError("`grad_position` and `weight` can not be None at the same time.")
 
     if grad_position is None:
-        return _get_grad_op(True, False, has_aux)(fn, weights)
+        return _get_grad_op(True, False, has_aux, False, return_ids)(fn, weights)
 
     grad_position = _convert_grad_position_type(grad_position)
     if weights is None:
-        return _get_grad_op(False, True, has_aux)(fn, None, grad_position)
-    return _get_grad_op(True, True, has_aux)(fn, weights, grad_position)
+        return _get_grad_op(False, True, has_aux, False, return_ids)(fn, None, grad_position)
+    return _get_grad_op(True, True, has_aux, False, return_ids)(fn, weights, grad_position)
 
 
 def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
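
For orientation, the three dispatch branches of `grad` above correspond to the documented call forms; a hedged usage sketch (the network and parameter choices are illustrative):

    # weights only: differentiate with respect to parameters
    grad_fn = grad(net, grad_position=None, weights=net.trainable_params())
    # positions only: differentiate with respect to selected inputs
    grad_fn = grad(net, grad_position=(0, 1))
    # both: the gradient function returns (input_gradients, parameter_gradients)
    grad_fn = grad(net, grad_position=0, weights=net.trainable_params())
    # any branch can additionally pair each gradient with its id
    grad_fn = grad(net, grad_position=0, weights=net.trainable_params(), return_ids=True)
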
@@ -216,7 +274,7 @@ def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops, nn
-        >>> from mindspore
+        >>> from mindspore import value_and_grad
        >>>
        >>> # Cell object to be differentiated
        >>> class Net(nn.Cell):
@@ -300,6 +358,55 @@ def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
     return _get_grad_op(True, True, has_aux, True)(fn, weights, grad_position)
 
 
+def get_grad(gradients, identifier):
+    """
+    When `return_ids` of :func:`mindspore.grad` is set to True, use its return value as `gradients`, then find
+    the specific gradient in `gradients` according to `identifier` .
+
+    Two typical cases are covered:
+
+    1. `identifier` is the position of the input tensor whose gradient is wanted.
+    2. `identifier` is a parameter of a network.
+
+    Args:
+        gradients (Union[tuple[int, Tensor], tuple[tuple, tuple]]): The return value of :func:`mindspore.grad`
+            when `return_ids` is set to True.
+        identifier (Union[int, Parameter]): The position number of a tensor, or a parameter that was used in
+            :func:`mindspore.grad`.
+
+    Returns:
+        The gradient of the tensor at the position, or of the parameter, specified by `identifier`.
+
+    Raises:
+        RuntimeError: If the gradient is not found.
+        TypeError: If the arguments do not belong to the required types.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import grad, get_grad
+        >>>
+        >>> # Cell object to be differentiated
+        >>> class Net(nn.Cell):
+        ...     def construct(self, x, y, z):
+        ...         return x * y * z
+        >>> x = Tensor([1, 2], mindspore.float32)
+        >>> y = Tensor([-2, 3], mindspore.float32)
+        >>> z = Tensor([0, 3], mindspore.float32)
+        >>> net = Net()
+        >>> out_grad = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
+        >>> output = get_grad(out_grad, 1)
+        >>> print(output)
+        [0. 6.]
+    """
+    return inner.GetGrad()(gradients, identifier)
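
When `grad` is asked for weight gradients with `return_ids=True`, each entry is keyed by the parameter name, so `get_grad` also accepts the `Parameter` object itself as the identifier. A hedged sketch (the `nn.Dense` setup is illustrative only):

    >>> net = nn.Dense(10, 1)
    >>> weights = net.trainable_params()
    >>> grad_fn = grad(net, grad_position=0, weights=weights, return_ids=True)
    >>> x = Tensor(np.ones((2, 10)), mindspore.float32)
    >>> output = grad_fn(x)
    >>> bias_grad = get_grad(output, net.bias)  # look up by Parameter instead of position
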
+
+
 def _trans_jet_inputs(primals_item, series_item):
     """Trans inputs of jet"""
     value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
@@ -376,15 +483,14 @@ def jet(fn, primals, series):
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore as ms
-        >>> import mindspore.ops as
+        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor
-        >>> from mindspore.ops.functional import jet
        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
-        ...         self.sin =
-        ...         self.exp =
+        ...         self.sin = ops.Sin()
+        ...         self.exp = ops.Exp()
        ...     def construct(self, x):
        ...         out1 = self.sin(x)
        ...         out2 = self.exp(out1)
@@ -392,7 +498,7 @@ def jet(fn, primals, series):
        >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
        >>> net = Net()
-        >>> out_primals, out_series = jet(net, primals, series)
+        >>> out_primals, out_series = ops.jet(net, primals, series)
        >>> print(out_primals, out_series)
        [[2.319777 2.4825778]
        [1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
@@ -487,15 +593,14 @@ def derivative(fn, primals, order):
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
-        >>> import mindspore.ops as
+        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor
-        >>> from mindspore.ops.functional import derivative
        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
-        ...         self.sin =
-        ...         self.exp =
+        ...         self.sin = ops.Sin()
+        ...         self.exp = ops.Exp()
        ...     def construct(self, x):
        ...         out1 = self.sin(x)
        ...         out2 = self.exp(out1)
@@ -503,7 +608,7 @@ def derivative(fn, primals, order):
        >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> order = 3
        >>> net = Net()
-        >>> out_primals, out_series = derivative(net, primals, order)
+        >>> out_primals, out_series = ops.derivative(net, primals, order)
        >>> print(out_primals, out_series)
        [[2.319777 2.4825778]
        [1.1515628 0.4691642]] [[-4.0515366 3.6724353 ]
@@ -541,10 +646,20 @@ def derivative(fn, primals, order):
     return out_primals, out_series
 
 
-
+_grad_single = GradOperation(sens_param=True)
+_grad_all = GradOperation(sens_param=True, get_all=True)
+
+
+@constexpr
+def _check_jvp_input_v_len(inputs_len, v_len):
+    if inputs_len != v_len:
+        raise ValueError(f'v has invalid length: should be {inputs_len}, but got {v_len}')
+
+
+def jvp(fn, inputs, v, has_aux=False):
     """
     Compute the jacobian-vector-product of the given network. `jvp` matches
-    `forward-mode differentiation <https://www.mindspore.cn/docs/en/
+    `forward-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#forward-mode-ad>`_.
 
     Args:
         fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
@@ -552,10 +667,16 @@ def jvp(fn, inputs, v):
         inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
         v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in jacobian-vector-product. The shape and type of `v`
            should be the same as `inputs` .
+        has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other
+            outputs will be returned directly. This means `fn` must return more than one output in this case.
+            Default: False.
 
     Returns:
-        - **net_output** (Union[Tensor, tuple[Tensor]]) - The
+        - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)` . Specially, when `has_aux` is
+          set True, `net_output` is the first output of `fn(inputs)` .
        - **jvp** (Union[Tensor, tuple[Tensor]]) - The result of jacobian-vector-product.
+        - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Only returned when `has_aux` is True. It holds
+          the second to the last outputs of `fn(inputs)` , which do not contribute to the gradient.
 
     Raises:
         TypeError: `inputs` or `v` does not belong to required types.
@@ -564,32 +685,102 @@ def jvp(fn, inputs, v):
        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
-        >>>
+        >>> import numpy as np
+        >>> from mindspore import jvp
        >>> from mindspore import Tensor
+        >>> import mindspore.nn as nn
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y):
        ...         return x**3 + y
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
-        >>> output =
+        >>> output = jvp(Net(), (x, y), (v, v))
        >>> print(output[0])
        [[ 2. 10.]
        [30. 68.]]
        >>> print(output[1])
        [[ 4. 13.]
        [28. 49.]]
+        >>>
+        >>> def fn(x, y):
+        ...     return x ** 3 + y, y
+        >>> output, jvp_out, aux = jvp(fn, (x, y), (v, v), has_aux=True)
+        >>> print(output)
+        [[ 2. 10.]
+        [30. 68.]]
+        >>> print(jvp_out)
+        [[ 4. 13.]
+        [28. 49.]]
+        >>> print(aux)
+        [[ 1. 2.]
+        [3. 4.]]
     """
-
-
-
-
-
+    _check_has_aux_type(has_aux)
+
+    def aux_fn(*args):
+        outputs = fn(*args)
+        if not isinstance(outputs, tuple) or len(outputs) < 2:
+            raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
+        res = outputs[0]
+        return res
+
+    def grad_single(u, first_grad_single_value):
+        if has_aux:
+            return _grad_single(aux_fn)(*first_grad_single_value, u)
+        return _grad_single(fn)(*first_grad_single_value, u)
+
+    def grad_all(u, first_grad):
+        if has_aux:
+            return _grad_all(aux_fn)(*first_grad, u)
+        return _grad_all(fn)(*first_grad, u)
+
+    def _wrap_container_inner(*arg):
+        jvp_inputs = arg[1:]
         vectors = arg[0]
-
+        if has_aux:
+            outputs = aux_fn(*jvp_inputs)
+        else:
+            outputs = fn(*jvp_inputs)
+        if isinstance(outputs, tuple):
+            u = ()
+            for item in outputs:
+                u = u + (mutable(oneslike(item)),)
+        else:
+            u = mutable(oneslike(outputs))
+        if len(jvp_inputs) == 1:
+            second_grad_net = _grad_single(grad_single)
+            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
+        else:
+            second_grad_net = _grad_single(grad_all)
+            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
+        if has_aux:
+            res = fn(*jvp_inputs)
+            if len(res) == 2:
+                return res[0], gradient_outputs, res[1]
+            return res[0], gradient_outputs, res[1:]
+        return outputs, gradient_outputs
+
+    if has_aux:
+        @jit(hash_args=aux_fn)
+        def _wrap_container(*arg):
+            return _wrap_container_inner(*arg)
+    else:
+        @jit(hash_args=fn)
+        def _wrap_container(*arg):
+            return _wrap_container_inner(*arg)
 
     if not isinstance(inputs, (Tensor, tuple, list)) or not isinstance(v, (Tensor, tuple, list)):
         _raise_type_error()
+
+    inputs_len = 1
+    v_len = 1
+    if isinstance(inputs, (tuple, list)):
+        inputs_len = len(inputs)
+    if isinstance(v, (tuple, list)):
+        v_len = len(v)
+    _check_jvp_input_v_len(inputs_len, v_len)
+
     if isinstance(v, list):
         v = tuple(v)
     if isinstance(inputs, (tuple, list)):
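
The implementation above obtains forward-mode products without a dedicated forward-AD pass: it builds a dummy sensitivity `u`, takes the reverse-mode gradient of `fn` with `u` as the sensitivity, then differentiates that result with respect to `u` using `v` as the new sensitivity, which recovers the Jacobian-vector product. A self-contained sketch of the same double-reverse trick (names and the single-input example are illustrative):

    import numpy as np
    from mindspore import Tensor, ops

    def f(x):
        return x ** 3

    x = Tensor(np.array([2.0], np.float32))
    v = Tensor(np.array([1.0], np.float32))
    g = ops.GradOperation(sens_param=True)

    def first_grad(u, x):
        return g(f)(x, u)          # reverse mode: u * f'(x)

    u = ops.ones_like(f(x))
    jv = g(first_grad)(u, x, v)    # d/du of u * f'(x), with sens v -> f'(x) * v
    print(jv)                      # [12.], since f'(2) = 3 * 2 ** 2
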
@@ -647,7 +838,7 @@ def linearize(fn, inputs):
     """
     linearize_inner = _LinearizeInner()
 
-    @
+    @jit(hash_args=fn)
     def _wrap_container(*arg):
         args = arg[1:-1]
         vectors = arg[-1]
@@ -664,24 +855,38 @@ def linearize(fn, inputs):
     return output, partial(_wrap_container, output, *inputs)
 
 
-def
+def _check_tensor(inputs):
+    if not isinstance(inputs, (Tensor, tuple)):
+        raise TypeError("The inputs type must be Tensor.")
+    if isinstance(inputs, tuple):
+        for item in inputs:
+            if not isinstance(item, (Tensor, tuple, list)):
+                raise TypeError("The inputs type must be Tensor.")
+    return True
+
+
+def vjp(fn, *inputs, has_aux=False):
     """
     Compute the vector-jacobian-product of the given network. `vjp` matches
-    `reverse-mode differentiation <https://www.mindspore.cn/docs/en/
-
-    Note:
-        This function is subjected to change in the future.
+    `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#reverse-mode-ad>`_.
 
     Args:
         fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
            Tensors.
        inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
-
-
+        has_aux (bool): If True, only the first output of `fn` contributes to the gradient of `fn`, while the other
+            outputs will be returned directly. This means `fn` must return more than one output in this case.
+            Default: False.
 
     Returns:
-
-
+        Forward outputs and a function to calculate the vjp.
+
+        - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)`. Specially, when `has_aux` is
+          set True, `net_output` is the first output of `fn(inputs)`.
+        - **vjp_fn** (Function) - Calculates the vector-jacobian-product. Its inputs are the vectors whose shape and
+          type should be the same as `net_output` .
+        - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - Only returned when `has_aux` is True. It holds
+          the second to the last outputs of `fn(inputs)`, which do not contribute to the gradient.
 
     Raises:
         TypeError: `inputs` or `v` does not belong to required types.
@@ -690,7 +895,9 @@ def vjp(fn, inputs, v):
        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
-        >>>
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import vjp
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y):
@@ -698,41 +905,505 @@ def vjp(fn, inputs, v):
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
-        >>>
-        >>> print(
+        >>> outputs, vjp_fn = vjp(Net(), x, y)
+        >>> print(outputs)
        [[ 2. 10.]
        [30. 68.]]
-        >>>
+        >>> gradient = vjp_fn(v)
+        >>> print(gradient)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 3.00000000e+00, 1.20000000e+01],
        [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
        [ 1.00000000e+00, 1.00000000e+00]]))
+        >>> def fn(x, y):
+        ...     return 2 * x + y, y ** 3
+        >>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
+        >>> gradient = vjp_fn(v)
+        >>> print(outputs)
+        [[ 3. 6.]
+        [ 9. 12.]]
+        >>> print(aux)
+        [[ 1. 8.]
+        [27. 64.]]
+        >>> print(gradient)
+        (Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 2.00000000e+00, 2.00000000e+00],
+        [ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 1.00000000e+00, 1.00000000e+00],
+        [ 1.00000000e+00, 1.00000000e+00]]))
     """
-
+    _check_tensor(inputs)
+    _check_has_aux_type(has_aux)
+
+    def aux_fn(*args):
+        outputs = fn(*args)
+        if not isinstance(outputs, tuple) or len(outputs) < 2:
+            raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
+        res = outputs[0]
+        return res
+
+    def wrap_container(*v):
+        _check_tensor(v)
+        if has_aux:
+            fn_ = aux_fn
+        else:
+            fn_ = fn
+        if len(v) == 1:
+            return _grad_all(fn_)(*inputs, v[0])
+        return _grad_all(fn_)(*inputs, v)
 
-
-
-
-
-    return
+    res = fn(*inputs)
+    if has_aux:
+        if len(res) == 2:
+            return res[0], wrap_container, res[1]
+        return res[0], wrap_container, res[1:]
+    return res, wrap_container
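
For a multi-output `fn` (without `has_aux`), `vjp_fn` takes one cotangent per output; `wrap_container` above packs them into the sensitivity tuple internally. A hedged sketch, assuming `ops` is imported and each vector matches the shape of its output:

    >>> def fn(x, y):
    ...     return x + y, x * y
    >>> outputs, vjp_fn = vjp(fn, x, y)
    >>> v1 = ops.ones_like(outputs[0])
    >>> v2 = ops.ones_like(outputs[1])
    >>> grads = vjp_fn(v1, v2)  # one cotangent per output of fn
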
 
-
-
-
-
-
-
-
+
+@constexpr
+def _jac_generate_target_dimension(x):
+    """For given x, generates the target dimension tuple (1, 2, ..., len(x) - 1, 0)."""
+    target_dimension = tuple(index + 1 for index, _ in enumerate(x[1:])) + (0,)
+    return target_dimension
+
+
+def _jacfwd_trans_item(item, inputs_shape, grad_position):
+    """Transform the origin item into the derivative of each output with respect to each input."""
+    output_wrt_input_all = ()
+    for i in grad_position:
+        origin_output_wrt_input = item[inputs_shape[i][1]:inputs_shape[i + 1][1]]
+        target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
+        temp = transpose(origin_output_wrt_input, target_dimension)
+        output_wrt_input = reshape(temp, temp.shape[:-1] + inputs_shape[i + 1][0])
+        output_wrt_input_all += (output_wrt_input,)
+    return output_wrt_input_all
+
+
+def _jac_postprocess(x, shape, grad_position, mode):
+    """Reformat the jacobian."""
+
+    if mode == 'forward':
+        func = _jacfwd_trans_item
+        args = (shape, grad_position)
+    else:
+        func = _jacrev_trans_item
+        args = (shape,)
+
+    if isinstance(x, tuple):
+        jacobian = ()
+        for item in x:
+            jacobian += func(item, *args)
+        res = jacobian
+    else:
+        res = func(x, *args)
+    if len(res) == 1:
+        return res[0]
+    input_num = len(grad_position)
+    if len(res) % input_num != 0:
+        raise ValueError("The numbers of inputs and outputs do not match.")
+    output_num = len(res) // input_num
+    if input_num == 1 or output_num == 1:
+        return res
+    jac = ()
+    for i in range(output_num):
+        input_grad = ()
+        for j in range(input_num):
+            if mode == 'forward':
+                grad_increment = (res[i * input_num + j],)
+            else:
+                grad_increment = (res[j * output_num + i],)
+            input_grad += grad_increment
+        jac += (input_grad,)
+    return jac
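
`_jac_generate_target_dimension` simply rotates the leading vmapped axis to the back: a stacked derivative of shape `(N, a, b)` is transposed to `(a, b, N)` before the trailing `N` axis is reshaped into the corresponding input (or output) shape. A quick check of the permutation in plain Python:

    >>> shape = (6, 2, 2)  # N = 6 stacked rows for an item of shape (2, 2)
    >>> tuple(index + 1 for index, _ in enumerate(shape[1:])) + (0,)
    (1, 2, 0)
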
+
+
+def _jacfwd_postprocess(x, inputs_shape, grad_position):
+    """Reformat the forward-computed Jacobian."""
+    return _jac_postprocess(x, inputs_shape, grad_position, 'forward')
+
+
+def _jacfwd_construct_v(inputs, grad_position):
+    """
+    For input (x1, x2), x1.shape = (a, b), x2.shape = (c, d), this method generates the corresponding v (v1, v2),
+    v1.shape = (N, a, b), v2.shape = (N, c, d), where N = a*b + c*d.
+    """
+    v = ()
+    primals = ()
+    inputs_shape = (((), 0),)
+    num = 0
+    items_num = ()
+    cum_num = (0,)
+    for item in inputs:
+        num += size(item)
+        inputs_shape += ((item.shape, num),)
+        items_num += (size(item),)
+        cum_num += (num,)
+    for i, element in enumerate(inputs):
+        item_size = items_num[i]
+        if i in grad_position:
+            temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
+        else:
+            temp2 = zeros((num, item_size), mstype.float32)
+        input_v = reshape(temp2, (num,) + element.shape)
+        primal = broadcast_to(element, (num,) + element.shape)
+        v += (input_v,)
+        primals += (primal,)
+    if len(inputs) == 1:
+        return primals, v[0], inputs_shape
+    return primals, v, inputs_shape
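
The `np.eye(num, item_size, -cum_num[i])` call slices out the block of the `num x num` identity that belongs to input `i`, yielding one one-hot tangent per vmapped row. For two inputs with 4 and 2 elements (`num = 6`, `cum_num[1] = 4`), the second input's block looks like this (illustrative check):

    >>> import numpy as np
    >>> np.eye(6, 2, -4, np.float32)
    array([[0., 0.],
           [0., 0.],
           [0., 0.],
           [0., 0.],
           [1., 0.],
           [0., 1.]], dtype=float32)
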
+
+
+_vmap = _Vmap()
+
+
+def jacfwd(fn, grad_position=0, has_aux=False):
+    """
+    Compute Jacobian via forward mode, corresponding to
+    `forward-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#forward-mode-ad>`_.
+    When the number of outputs is much greater than the number of inputs, computing the Jacobian in forward mode is
+    faster than in reverse mode.
+
+    Args:
+        fn (Union[Cell, Function]): Function to do GradOperation.
+        grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to a single input.
+            If tuple, get the gradients with respect to the selected inputs. 'grad_position' begins with 0.
+            Default: 0.
+        has_aux (bool, optional): If True, only the first output of `fn` contributes to the gradient of `fn`,
+            while the other outputs will be returned directly. This means `fn` must return more than one output
+            in this case. Default: False.
+
+    Returns:
+        Function, returns the Jacobian function for the input function or cell.
+        For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, the gradient function will return
+        outputs like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
+
+    Raises:
+        TypeError: `grad_position` or `has_aux` does not belong to required types.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import jacfwd
+        >>> from mindspore import Tensor
+        >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
+        ...     def construct(self, x, y, z):
+        ...         return x ** 2 + y ** 2 + z ** 2, x * y * z
+        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+        >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+        >>> net = MultipleInputsMultipleOutputsNet()
+        >>> jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
+        >>> print(jac)
+        [[[[ 2., 0.]
+        [ 0., 0.]]
+        [[ 0., 4.]
+        [ 0., 0.]]]
+        [[[ 0., 0.]
+        [ 6., 0.]]
+        [[ 0., 0.]
+        [ 0., 8.]]]]
+        >>> print(aux)
+        [[ 1. 4.]
+        [ 9. 16.]]
+    """
+    _check_has_aux_type(has_aux)
+
+    def aux_fn(*args):
+        outputs = fn(*args)
+        if not isinstance(outputs, tuple) or len(outputs) < 2:
+            raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
+        res = outputs[0]
+        return res
+
+    def grad_single(u, first_grad_single_value):
+        if has_aux:
+            return _grad_single(aux_fn)(*first_grad_single_value, u)
+        return _grad_single(fn)(*first_grad_single_value, u)
+
+    def grad_all(u, first_grad):
+        if has_aux:
+            return _grad_all(aux_fn)(*first_grad, u)
+        return _grad_all(fn)(*first_grad, u)
+
+    @jit
+    def wrapped(*args):
+        checked_grad_position = _check_grad_position(grad_position, len(args))
+        primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)
+
+        def inner_fn(jvp_inputs, vectors):
+            outputs = fn(*jvp_inputs)
+            if isinstance(outputs, tuple):
+                u = ()
+                for item in outputs:
+                    u = u + (mutable(oneslike(item)),)
+            else:
+                u = mutable(oneslike(outputs))
+            if len(jvp_inputs) == 1:
+                second_grad_net = _grad_single(grad_single)
+            else:
+                second_grad_net = _grad_single(grad_all)
+            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
+            return gradient_outputs
+
+        def inner_aux_fn(jvp_inputs, vectors):
+            outputs = aux_fn(*jvp_inputs)
+            u = mutable(oneslike(outputs))
+            if len(jvp_inputs) == 1:
+                second_grad_net = _grad_single(grad_single)
+            else:
+                second_grad_net = _grad_single(grad_all)
+            gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
+            return gradient_outputs
+
+        if has_aux:
+            res = _vmap(inner_aux_fn)(primals, v)
+            jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
+            forward_outputs = fn(*args)
+            if len(forward_outputs) == 2:
+                return jac_res, forward_outputs[1]
+            return jac_res, forward_outputs[1:]
+        res = _vmap(inner_fn)(primals, v)
+        jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
+        return jac_res
+
+    return wrapped
+
+
+def _jacrev_trans_item(item, outputs_shape):
+    """Transform the origin item into the derivative of each output with respect to each input."""
+    output_wrt_input_all = ()
+    length = len(outputs_shape) - 1
+    for i in range(length):
+        origin_output_wrt_input = item[outputs_shape[i][1]:outputs_shape[i + 1][1]]
+        target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
+        temp = transpose(origin_output_wrt_input, target_dimension)
+        output_wrt_input = reshape(origin_output_wrt_input, outputs_shape[i + 1][0] + temp.shape[:-1])
+        output_wrt_input_all += (output_wrt_input,)
+    return output_wrt_input_all
+
+
+def _jacrev_postprocess(x, outputs_shape, grad_position):
+    """Reformat the reverse-computed Jacobian."""
+    return _jac_postprocess(x, outputs_shape, grad_position, 'reverse')
+
+
+def _jacrev_construct_v(inputs, outputs, has_aux=False):
+    """
+    For outputs (y1, y2), y1.shape = (a, b), y2.shape = (c, d), this method generates the corresponding v (v1, v2),
+    v1.shape = (N, a, b), v2.shape = (N, c, d), where N = a*b + c*d.
+    """
+    if isinstance(outputs, Tensor):
+        outputs = (outputs,)
+    if has_aux:
+        outputs = (outputs[0],)
+    v = ()
+    primals = ()
+    outputs_shape = (((), 0),)
+    num = 0
+    items_num = ()
+    cum_num = (0,)
+    for item in outputs:
+        item_num = size(item)
+        num += item_num
+        outputs_shape += ((item.shape, num),)
+        items_num += (item_num,)
+        cum_num += (num,)
+    for element in inputs:
+        primal = broadcast_to(element, (num,) + element.shape)
+        primals += (primal,)
+    for i, element in enumerate(outputs):
+        item_size = items_num[i]
+        temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
+        output_v = reshape(temp2, (num,) + element.shape)
+        v += (output_v,)
+    if len(outputs) == 1 or has_aux:
+        return primals, v[0], outputs_shape
+    return primals, v, outputs_shape
+
+
+_grad = _Grad(get_by_position=True, has_aux=False, sens_param=True)
+
+
+def jacrev(fn, grad_position=0, has_aux=False):
+    """
+    Compute Jacobian via reverse mode, corresponding to
+    `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#reverse-mode-ad>`_.
+    When the number of inputs is much greater than the number of outputs, computing the Jacobian in reverse mode is
+    faster than in forward mode.
+
+    Args:
+        fn (Union[Cell, Function]): Function to do GradOperation.
+        grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to a single input.
+            If tuple, get the gradients with respect to the selected inputs. 'grad_position' begins with 0.
+            Default: 0.
+        has_aux (bool, optional): If True, only the first output of `fn` contributes to the gradient of `fn`,
+            while the other outputs will be returned directly. This means `fn` must return more than one output
+            in this case. Default: False.
+
+    Returns:
+        Function, returns the Jacobian function for the input function or cell.
+        For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, the gradient function will return
+        outputs like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
+
+    Raises:
+        TypeError: `grad_position` or `has_aux` does not belong to required types.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import jacrev
+        >>> from mindspore import Tensor
+        >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
+        ...     def construct(self, x, y, z):
+        ...         return x ** 2 + y ** 2 + z ** 2, x * y * z
+        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
+        >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
+        >>> net = MultipleInputsMultipleOutputsNet()
+        >>> jac, aux = jacrev(net, grad_position=0, has_aux=True)(x, y, z)
+        >>> print(jac)
+        [[[[ 2., 0.]
+        [ 0., 0.]]
+        [[ 0., 4.]
+        [ 0., 0.]]]
+        [[[ 0., 0.]
+        [ 6., 0.]]
+        [[ 0., 0.]
+        [ 0., 8.]]]]
+        >>> print(aux)
+        [[ 1. 4.]
+        [ 9. 16.]]
+    """
+    _check_has_aux_type(has_aux)
+
+    def aux_fn(*args):
+        outputs = fn(*args)
+        if not isinstance(outputs, tuple) or len(outputs) < 2:
+            raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
+        res = outputs[0]
+        return res
+
+    @jit
+    def wrapped(*args):
+        checked_grad_position = _check_grad_position(grad_position, len(args))
+        outputs = fn(*args)
+        primals, v, outputs_shape = _jacrev_construct_v(args, outputs, has_aux)
+
+        def inner_fn(vjp_inputs, vectors):
+            gradient_outputs = _grad(fn, None, checked_grad_position)(*vjp_inputs, vectors)
+            return gradient_outputs
+
+        def inner_aux_fn(vjp_inputs, vectors):
+            gradient_outputs = _grad(aux_fn, None, checked_grad_position)(*vjp_inputs, vectors)
+            return gradient_outputs
+
+        if has_aux:
+            res = _vmap(inner_aux_fn)(primals, v)
+            jac_res = _jacrev_postprocess(res, outputs_shape, checked_grad_position)
+            forward_outputs = fn(*args)
+            if len(forward_outputs) == 2:
+                return jac_res, forward_outputs[1]
+            return jac_res, forward_outputs[1:]
+
+        res = _vmap(inner_fn)(primals, v)
+        jac_res = _jacrev_postprocess(res, outputs_shape, checked_grad_position)
+        return jac_res
+
+    return wrapped
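
Both wrappers vmap one gradient pass per basis vector: `_jacfwd_construct_v` stacks one row per input element, while `_jacrev_construct_v` stacks one row per output element, so the cheaper mode is the one whose side is smaller, as the docstrings note. A hedged helper sketch (hypothetical, not part of the package):

    def pick_jacobian_fn(fn, total_in, total_out, grad_position=0):
        # Prefer the mode with the smaller vmapped batch.
        if total_out >= total_in:
            return jacfwd(fn, grad_position)
        return jacrev(fn, grad_position)
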
+
+
+def custom_vjp(fn=None):
+    """
+    Support custom bprop (vjp) for a function.
+
+    Args:
+        fn (function): The function that needs a custom bprop. Default: None.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+    """
+
+    def deco(fn):
+        class CustomVjp(Cell):
+            """
+            CustomVjp wraps the function into a cell to support a custom bprop.
+            """
+
+            def __init__(self, fwd):
+                super(CustomVjp, self).__init__()
+                self.fwd = fwd
+                self.bwd = None
+                self.add_flags(custom_vjp=True)
+
+            def construct(self, *args):
+                return self.fwd(*args)
+
+            def defbwd(self, bwd):
+                self.bwd = bwd
+
+            def bprop(self, *args):
+                return self.bwd(*args)
+
+        return CustomVjp(fn)
+
+    if fn is not None:
+        return deco(fn)
+    return deco
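
A hedged usage sketch for `custom_vjp`; the `(x, out, dout)` backward signature follows the usual Cell custom-bprop convention, and the numeric example is illustrative only:

    >>> import mindspore
    >>> from mindspore import Tensor, ops
    >>> @custom_vjp
    ... def square(x):
    ...     return x ** 2
    >>> def square_bwd(x, out, dout):
    ...     return (2 * x * dout,)
    >>> square.defbwd(square_bwd)
    >>> print(ops.grad(square)(Tensor([3.0], mindspore.float32)))
    [6.]
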
+
+
+def stop_gradient(value):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
+    the gradient propagation from an output of a function.
+    For more details, please refer to `Stop Gradient
+    <https://www.mindspore.cn/tutorials/en/r2.0/beginner/autograd.html#stop-gradient>`_.
+
+    Args:
+        value (Any): The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        The same as `value`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.stop_gradient(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6 6.5999994]
+        [1.4100001 1.6 6.5999994]]
+    """
+    return P.StopGradient()(value)
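
In the example above, `out2` is excluded from differentiation, so the gradient with respect to `x` reduces to `ones_like(out1) @ y.T`, i.e. every row holds the row sums of `y`. A quick NumPy check of the printed values (illustrative only):

    >>> import numpy as np
    >>> y = np.array([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], np.float32)
    >>> np.ones((2, 3), np.float32) @ y.T  # rows ~= [1.41, 1.6, 6.6], matching the printed gradient
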
 
 
 __all__ = [
     'grad',
     'value_and_grad',
+    'jacfwd',
+    'jacrev',
     'jet',
     'derivative',
     'jvp',
     'vjp',
-    'linearize'
+    'linearize',
+    'stop_gradient',
+    'get_grad'
 ]
 __all__.sort()