mindspore-1.10.0-cp38-cp38-win_amd64.whl → mindspore-2.0.0rc1-cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/ops/composite/base.py
CHANGED
|
@@ -18,24 +18,24 @@
|
|
|
18
18
|
"""Basic composite operations."""
|
|
19
19
|
from __future__ import absolute_import
|
|
20
20
|
from functools import partial
|
|
21
|
+
|
|
21
22
|
from types import FunctionType, MethodType
|
|
22
23
|
import mindspore as ms
|
|
23
|
-
import mindspore.nn as nn
|
|
24
24
|
from mindspore import context
|
|
25
25
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
26
|
-
from mindspore import
|
|
27
|
-
from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_,
|
|
26
|
+
from mindspore.parallel._utils import _grads_divided_by_device_num_if_recomputation
|
|
27
|
+
from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
|
|
28
28
|
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
|
29
29
|
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
|
|
30
|
-
ListClear_, ListReverse_, ListExtend_,
|
|
30
|
+
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
|
|
31
|
+
ZerosLike_
|
|
31
32
|
from mindspore.common import dtype as mstype
|
|
32
|
-
from mindspore.common.api import
|
|
33
|
+
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
|
34
|
+
from mindspore.common.api import _add_flags, _core
|
|
33
35
|
from mindspore.ops.primitive import Primitive
|
|
34
|
-
from mindspore.ops.operations import _grad_ops
|
|
35
|
-
from mindspore.ops import operations as P
|
|
36
36
|
from mindspore.ops import signature as sig
|
|
37
37
|
|
|
38
|
-
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_]
|
|
38
|
+
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_, ZerosLike_]
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
def add_flags(fn=None, **flags):
|
|
@@ -59,18 +59,7 @@ def add_flags(fn=None, **flags):
|
|
|
59
59
|
True
|
|
60
60
|
"""
|
|
61
61
|
|
|
62
|
-
|
|
63
|
-
# need set the attr and access on c++
|
|
64
|
-
if not hasattr(fn, "_func_graph_flags"):
|
|
65
|
-
fn._func_graph_flags = {}
|
|
66
|
-
|
|
67
|
-
fn._func_graph_flags.update({**flags})
|
|
68
|
-
return fn
|
|
69
|
-
|
|
70
|
-
ret = deco
|
|
71
|
-
if fn is not None:
|
|
72
|
-
ret = deco(fn)
|
|
73
|
-
return ret
|
|
62
|
+
return _add_flags(fn, **flags)
|
|
74
63
|
|
|
75
64
|
|
|
76
65
|
def core(fn=None, **flags):
|
|
@@ -81,8 +70,8 @@ def core(fn=None, **flags):
|
|
|
81
70
|
set flag to a graph.
|
|
82
71
|
|
|
83
72
|
Args:
|
|
84
|
-
fn (Function): Function to add flag. Default: None.
|
|
85
|
-
flags (dict): The following flags can be set core, which indicates that this is a core function or
|
|
73
|
+
fn (Function, optional): Function to add flag. Default: None.
|
|
74
|
+
flags (dict, optional): The following flags can be set core, which indicates that this is a core function or
|
|
86
75
|
other flag. Default: None.
|
|
87
76
|
|
|
88
77
|
Supported Platforms:
|
|
@@ -95,31 +84,18 @@ def core(fn=None, **flags):
|
|
|
95
84
|
True
|
|
96
85
|
"""
|
|
97
86
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def deco(fn):
|
|
101
|
-
fn._func_graph_flags = {
|
|
102
|
-
'core': True,
|
|
103
|
-
**flags,
|
|
104
|
-
}
|
|
105
|
-
return fn
|
|
106
|
-
|
|
107
|
-
if fn is not None:
|
|
108
|
-
ret = deco(fn)
|
|
109
|
-
else:
|
|
110
|
-
ret = deco
|
|
111
|
-
return ret
|
|
87
|
+
return _core(fn, **flags)
|
|
112
88
|
|
|
113
89
|
|
|
114
90
|
def _get_grad_weights_id(weights=None):
|
|
115
91
|
"""generate id of parameters"""
|
|
116
92
|
res = ""
|
|
117
93
|
if isinstance(weights, Parameter):
|
|
118
|
-
res = weights.name
|
|
94
|
+
res = weights.name + str(weights.requires_grad)
|
|
119
95
|
if isinstance(weights, ParameterTuple):
|
|
120
|
-
res = ''.join(item.name for item in weights)
|
|
96
|
+
res = ''.join(item.name + str(item.requires_grad) for item in weights)
|
|
121
97
|
if isinstance(weights, list):
|
|
122
|
-
res = ''.join(item.name for item in weights if isinstance(item, Parameter))
|
|
98
|
+
res = ''.join(item.name + str(item.requires_grad) for item in weights if isinstance(item, Parameter))
|
|
123
99
|
return res
|
|
124
100
|
|
|
125
101
|
|
|
@@ -130,83 +106,85 @@ class GradOperation(GradOperation_):
     The gradient function generated by `GradOperation` higher-order function can be customized by
     construction arguments.
 
-
+    For example, given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
     see `Net` in Examples.
 
+    - Used to get the derivative of the input:
 
-
-    (see `GradNetWrtX` in Examples).
+        1. Returns gradients with respect to the first input (see `GradNetWrtX` in Examples).
 
-
-    `grad_op = GradOperation()`.
+           1) Construct a `GradOperation` higher-order function with default arguments: `grad_op = GradOperation()`.
 
-
+           2) Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
 
-
-
+           3) Call the gradient function with input function's inputs to get the gradients with respect to the first
+              input: `grad_op(net)(x, y)`.
 
+        2. Returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples).
 
-
+           1) Construct a `GradOperation` higher-order function with `get_all=True` which indicates getting gradients
+              with respect to all inputs, they are `x` and `y` in example function `Net()`:
+              `grad_op = GradOperation(get_all=True)`.
 
-
-    indicates getting gradients with respect to all inputs, they are `x` and `y` in example function `Net()`:
-    `grad_op = GradOperation(get_all=True)`.
+           2) Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
 
-
+           3) Call the gradient function with input function's inputs to get the gradients with respect to all inputs:
+              `gradient_function(x, y)`.
 
-
-    `gradient_function(x, y)`.
+    - Used to get the derivative of the parameters:
 
-
-    (see `GradNetWithWrtParams` in Examples).
+        Returns gradients with respect to given parameters (see `GradNetWithWrtParams` in Examples).
 
-
-
+        1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
+           `grad_op = GradOperation(get_by_list=True)`.
 
-
-
-
+        2. Construct a `ParameterTuple` that will be passed to the input function when constructing
+           `GradOperation` higher-order function, it will be used as a parameter filter that determine
+           which gradient to return: `params = ParameterTuple(net.trainable_params())`.
 
-
-
+        3. Call it with input function and `params` as arguments to get the gradient function:
+           `gradient_function = grad_op(net, params)`.
 
-
-
+        4. Call the gradient function with input function's inputs to get the gradients with
+           respect to given parameters: `gradient_function(x, y)`.
 
-
-
+    - Used to get the derivative of the inputs and parameters at the same time:
+      Returns gradients with respect to all inputs and given parameters in the format of ((dx, dy), (dz))
+      (see `GradNetWrtInputsAndParams` in Examples).
 
-
-
+        1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
+           `grad_op = GradOperation(get_all=True, get_by_list=True)`.
 
-
-
+        2. Construct a `ParameterTuple` that will be passed along input function when constructing
+           `GradOperation` higher-order function: `params = ParameterTuple(net.trainable_params())`.
 
-
-
+        3. Call it with input function and `params` as arguments to get the gradient function:
+           `gradient_function = grad_op(net, params)`.
 
-
-
+        4. Call the gradient function with input function's inputs to get the gradients with respect to
+           all inputs and given parameters: `gradient_function(x, y)`.
 
-
-
+    - We can configure the sensitivity(gradient with respect to output) by setting `sens_param` as True and
+      passing an extra sensitivity input to the gradient function, the sensitivity input should has the
+      same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).
 
-
-
-    same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples).
+        1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
+           `grad_op = GradOperation(get_all=True, sens_param=True)`.
 
-
-
+        2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output:
+           `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.
 
-
-    `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.
+        3. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
 
-
-
+        4. Call the gradient function with input function's inputs and `sens_param` to
+           get the gradients with respect to all inputs: `gradient_function(x, y, grad_wrt_output)`.
 
-
-
-
+    Note:
+        For above gradient functions, the returned gradient result may vary for grad result element number:
+
+        - Return a single value if only one result.
+        - Return a tuple for multiple results.
+        - Return an empty tuple for no result.
 
     Args:
         get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
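
The `Net`/`GradNetWrtX` pattern the rewritten docstring refers to, spelled out end to end (a sketch against the 2.0 public API; shapes and values are illustrative):

    import numpy as np
    from mindspore import Tensor, Parameter, nn, ops

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = ops.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            return self.matmul(x, y)

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.grad_op = ops.GradOperation()  # defaults: gradient w.r.t. the first input

        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Tensor(np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], np.float32))
    y = Tensor(np.array([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], np.float32))
    print(GradNetWrtX(Net())(x, y))  # gradient w.r.t. x, same shape as x
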
@@ -351,7 +329,7 @@ class GradOperation(GradOperation_):
         self.get_all = get_all
         self.get_by_list = get_by_list
         self.sens_param = sens_param
-        GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False)
+        GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False)
         self.grad_fn = None
         self.fn = None
         self.weights_id = None
@@ -363,34 +341,35 @@ class GradOperation(GradOperation_):
         if self.grad_fn is not None and self.fn == fn and self.weights_id == weights_id:
             return self.grad_fn
         grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
-        # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
+        # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
         # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
         # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
-        # In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
+        # In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
+        # grad in GRAPH_MODE.
         if context.get_context("mode") == context.GRAPH_MODE:
             dynamic_shape_inputs = None
             if isinstance(fn, ms.nn.Cell):
                 dynamic_shape_inputs = fn.get_inputs()
                 fn.grad_ops_label = True
             if self.get_by_list:
-                @ms_function(input_signature=dynamic_shape_inputs)
-                def after_grad(*args):
-                    return grad_(fn, weights)(*args)
+                @jit(input_signature=dynamic_shape_inputs)
+                def after_grad(*args, **kwargs):
+                    return grad_(fn, weights)(*args, **kwargs)
             else:
-                @ms_function(input_signature=dynamic_shape_inputs)
-                def after_grad(*args):
-                    return grad_(fn)(*args)
+                @jit(input_signature=dynamic_shape_inputs)
+                def after_grad(*args, **kwargs):
+                    return grad_(fn)(*args, **kwargs)
         elif self.pynative_:
             @_wrap_func
             def after_grad(*args, **kwargs):
-                self._pynative_forward_run(fn, grad_, args, kwargs)
+                self._pynative_forward_run(fn, grad_, weights, args, kwargs)
                 _pynative_executor.grad(fn, grad_, weights, self.grad_position, *args, **kwargs)
-                out = _pynative_executor(
-
+                out = _pynative_executor()
+                out = _grads_divided_by_device_num_if_recomputation(out)
                 return out
         else:
             grad_.pynative_ = True
-            # after_grad of this branch can't use @ms_function, just directly call grad_
+            # after_grad of this branch can't use @jit, just directly call grad_
             if self.get_by_list:
                 def after_grad(*args, **kwargs):
                     return grad_(fn, weights)(*args, **kwargs)
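
At the user level, the decorator swap above is the 1.10 `ms_function` to 2.0 `jit` rename; a minimal sketch of the new spelling:

    from mindspore import Tensor, jit

    @jit
    def double(x):
        # compiled to a graph on first call
        return x * 2

    print(double(Tensor(3.0)))
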
@@ -403,7 +382,7 @@ class GradOperation(GradOperation_):
         self.weights_id = weights_id
         return self.grad_fn
 
-    def _pynative_forward_run(self, fn, grad, args, kwargs):
+    def _pynative_forward_run(self, fn, grad, weights, args, kwargs):
         """ Pynative forward run to build grad graph. """
         new_kwargs = kwargs
         if self.sens_param:
@@ -413,14 +392,14 @@ class GradOperation(GradOperation_):
             new_kwargs = kwargs.copy()
             new_kwargs.pop('sens')
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, None, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, weights, None, *args, **new_kwargs):
                 _pynative_executor.set_grad_flag(True)
                 _pynative_executor.new_graph(fn, *args, **new_kwargs)
                 output = fn(*args, **new_kwargs)
                 _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
         else:
             # Check if fn have run already
-            if not _pynative_executor.check_run(grad, fn, None, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, weights, None, *args, **new_kwargs):
                 fn.set_grad()
                 fn(*args, **new_kwargs)
                 fn.set_grad(False)
@@ -442,9 +421,9 @@ class _TaylorOperation(TaylorOperation_):
             return self.grad_fn
         taylor_grad_ = _TaylorOperation()
 
-        # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
+        # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
 
-        @ms_function
+        @jit
         def after_taylor_grad(*args):
             return taylor_grad_(fn)(*args)
 
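
`_TaylorOperation` backs the Taylor-mode differentiation helpers; a sketch assuming the public `ops.derivative(fn, primals, order)` entry point keeps its documented signature in this release:

    import numpy as np
    from mindspore import Tensor, ops

    def fn(x):
        return ops.sin(x)

    primals = Tensor(np.array([1.0], np.float32))
    out_primals, out_series = ops.derivative(fn, primals, 3)  # value and 3rd-order term
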
@@ -453,12 +432,77 @@ class _TaylorOperation(TaylorOperation_):
         return self.grad_fn
 
 
+def _combine_weight(grad_position, weights, out, out_with_ids):
+    """ Making resulting tuple for weight, when return_ids is set to True. """
+    weight_tuple = []
+    position = 0
+    if isinstance(weights, (list, ParameterTuple, tuple)) and grad_position:
+        for weight in weights:
+            weight_tuple.append((weight.name, out[1][position]))
+            position += 1
+    elif isinstance(weights, (list, ParameterTuple, tuple)):
+        for weight in weights:
+            weight_tuple.append((weight.name, out[position]))
+            position += 1
+    elif grad_position:
+        weight_tuple.append(weights.name)
+        weight_tuple.append(out[1])
+    else:
+        weight_tuple.append(weights.name)
+        weight_tuple.append(out)
+    if grad_position:
+        out_with_ids.append(tuple(weight_tuple))
+    else:
+        out_with_ids = weight_tuple
+    return out_with_ids
+
+
+def _combine_position(grad_position, weights, out, out_with_ids):
+    """ Making resulting tuple for position, when return_ids is set to True. """
+    position_tuple = []
+    position = 0
+    if grad_position == (0,) and weights is not None:
+        position_tuple.append(0)
+        position_tuple.append(out[0])
+    elif grad_position == (0,):
+        position_tuple.append(0)
+        position_tuple.append(out)
+    elif weights is not None:
+        for index in grad_position:
+            position_tuple.append((index, out[0][position]))
+            position += 1
+    else:
+        for index in grad_position:
+            position_tuple.append((index, out[position]))
+            position += 1
+    if weights:
+        out_with_ids.append(tuple(position_tuple))
+    else:
+        out_with_ids = position_tuple
+    return out_with_ids
+
+
+def _combine_with_ids(grad_position, weights, out):
+    """ Making resulting tuple, when return_ids is set to True. """
+    out_with_ids = []
+    if grad_position:
+        out_with_ids = _combine_position(
+            grad_position, weights, out, out_with_ids)
+    if weights is not None:
+        out_with_ids = _combine_weight(
+            grad_position, weights, out, out_with_ids)
+    if not out_with_ids:
+        raise ValueError(f"output tuple should not be a empty tuple.")
+    return tuple(out_with_ids)
+
+
 class _Grad(GradOperation_):
     """
     A higher-order function which is used to generate the gradient function by position for the input function.
     """
 
-    def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False, get_value=False):
+    def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False, get_value=False,
+                 return_ids=False):
         """Initialize _Grad."""
         if not isinstance(get_by_position, bool):
             raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
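
The `_combine_*` helpers implement the pairing that `return_ids=True` produces: each positional gradient is tagged with its input index and each weight gradient with its parameter name. A usage sketch, assuming the public `ops.grad` wrapper exposes `return_ids` in this release:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn, ops

    net = nn.Dense(3, 2)
    params = ms.ParameterTuple(net.trainable_params())
    grad_fn = ops.grad(net, grad_position=0, weights=params, return_ids=True)
    out = grad_fn(Tensor(np.ones((1, 3), np.float32)))
    # positional gradients arrive as (index, grad) pairs and weight gradients as
    # (parameter_name, grad) pairs, per _combine_position/_combine_weight above
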
@@ -475,18 +519,22 @@ class _Grad(GradOperation_):
         if not isinstance(get_value, bool):
             raise TypeError(f"For '_Grad', the 'get_value' should be bool, "
                             f"but got {type(get_value).__name__}")
+        if not isinstance(return_ids, bool):
+            raise TypeError(f"For '_Grad', the 'return_ids' should be bool, "
+                            f"but got {type(return_ids).__name__}")
         self.get_by_position = get_by_position
         self.get_by_list = get_by_list
         self.sens_param = sens_param
         self.has_aux = has_aux
         self.get_value = get_value
-        GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position, has_aux, get_value)
+        self.return_ids = return_ids
+        GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position, has_aux, get_value,
+                                return_ids)
         self.grad_fn = None
         self.fn = None
         self.pynative_ = False
         self.grad_position = None
         self.weights_id = None
-        self.grad_hash_id = None
 
     def __call__(self, fn, weights=None, grad_position=0):
         weights_id = _get_grad_weights_id(weights)
@@ -499,41 +547,44 @@ class _Grad(GradOperation_):
             if not isinstance(outputs, tuple) or len(outputs) < 2:
                 raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
             res = (outputs[0],)
-            stop_gradient = Primitive("stop_gradient")
+            stop_gradient = Primitive("StopGradient")
             for item in outputs[1:]:
                 res += (stop_gradient(item),)
             return res
 
-        grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value)
-        # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
+        grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux, self.get_value,
+                      self.return_ids)
+        # If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
         # If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
         # In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
-        # In PYNATIVE_MODE calling Grad from ms_function, use the out layer after_grad do grad in GRAPH_MODE.
+        # In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
+        # grad in GRAPH_MODE.
         if context.get_context("mode") == context.GRAPH_MODE:
             dynamic_shape_inputs = None
             if isinstance(fn, ms.nn.Cell):
                 dynamic_shape_inputs = fn.get_inputs()
             if self.get_by_position:
-                @ms_function(input_signature=dynamic_shape_inputs)
+                @jit(input_signature=dynamic_shape_inputs)
                 def after_grad(*args):
                     return grad_(fn, weights, grad_position)(*args)
             else:
                 if self.get_by_list:
-                    @ms_function(input_signature=dynamic_shape_inputs)
+                    @jit(input_signature=dynamic_shape_inputs)
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)
                 else:
-                    @ms_function(input_signature=dynamic_shape_inputs)
+                    @jit(input_signature=dynamic_shape_inputs)
                     def after_grad(*args):
                         return grad_(fn)(*args)
         elif self.pynative_:
-
             @_wrap_func
             def after_grad(*args, **kwargs):
-                res = self._pynative_forward_run(fn, grad_, args, kwargs)
+                res = self._pynative_forward_run(fn, grad_, weights, args, kwargs)
                 _pynative_executor.grad(fn, grad_, weights, grad_position, *args, **kwargs)
-                out = _pynative_executor(
-
+                out = _pynative_executor()
+                out = _grads_divided_by_device_num_if_recomputation(out)
+                if self.return_ids and out:
+                    out = _combine_with_ids(grad_position, weights, out)
                 if self.get_value:
                     return res, out
                 if self.has_aux:
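
The `aux_fn` wrapper above is what `has_aux` hooks in: every output after the first is routed through `StopGradient`, so only the first output drives differentiation. A sketch, assuming `ops.grad`'s `has_aux` flag in this release:

    import mindspore as ms
    from mindspore import Tensor, ops

    def fn(x):
        loss = x ** 2
        return loss, x + 1  # the second output is auxiliary, excluded from grad

    grad_fn = ops.grad(fn, has_aux=True)
    out = grad_fn(Tensor(3.0, ms.float32))
    # out carries d(loss)/dx together with the stop-gradient'ed auxiliary output
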
@@ -544,7 +595,7 @@ class _Grad(GradOperation_):
             fn_ = fn
             if self.has_aux:
                 fn_ = aux_fn
-            # after_grad of this branch can't use @ms_function, just directly call grad_
+            # after_grad of this branch can't use @jit, just directly call grad_
             if self.get_by_position:
                 def after_grad(*args, **kwargs):
                     return grad_(fn_, weights, grad_position)(*args, **kwargs)
@@ -560,10 +611,9 @@ class _Grad(GradOperation_):
         self.fn = fn
         self.grad_position = grad_position
         self.weights_id = weights_id
-        self.grad_hash_id = (grad_position, weights_id)
         return self.grad_fn
 
-    def _pynative_forward_run(self, fn, grad, args, kwargs):
+    def _pynative_forward_run(self, fn, grad, weights, args, kwargs):
         """ Pynative forward runs to build grad graph. """
         new_kwargs = kwargs
         outputs = ()
@@ -574,7 +624,7 @@ class _Grad(GradOperation_):
         else:
             args = args[:-1]
         if isinstance(fn, (FunctionType, MethodType)):
-            if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *args, **new_kwargs):
                 _pynative_executor.set_grad_flag(True)
                 _pynative_executor.new_graph(fn, *args, **new_kwargs)
                 outputs = fn(*args, **new_kwargs)
@@ -582,7 +632,7 @@ class _Grad(GradOperation_):
             return outputs
         else:
             # Check if fn has run already.
-            if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
+            if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *args, **new_kwargs):
                 fn.set_grad()
                 outputs = fn(*args, **new_kwargs)
                 fn.set_grad(False)
@@ -602,23 +652,28 @@ class _Vmap(VmapOperation_):
         VmapOperation_.__init__(self, 'vmap')
         self.vmap_fn = None
         self.fn = None
+        self.in_axes = None
+        self.out_axes = None
 
     def __call__(self, fn, in_axes=0, out_axes=0):
+        if self.vmap_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes:
+            return self.vmap_fn
+
         vmap_ = self
 
-        @ms_function
-        def after_vmap(*args):
-            return vmap_(fn, in_axes, out_axes)(*args)
+        @jit
+        def after_vmap(*args, **kwargs):
+            return vmap_(fn, in_axes, out_axes)(*args, **kwargs)
 
         self.vmap_fn = after_vmap
         self.fn = fn
+        self.in_axes = in_axes
+        self.out_axes = out_axes
         return self.vmap_fn
 
 
 class MultitypeFuncGraph(MultitypeFuncGraph_):
     """
-    Generates overloaded functions.
-
     MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs.
     Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator
     for the function to be registered. And the object can be called with different types of inputs,
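
The cache added to `_Vmap.__call__` in the hunk above means repeated calls with identical `fn`/`in_axes`/`out_axes` reuse one compiled function; a usage sketch through the stable `ops.vmap` entry point:

    import numpy as np
    from mindspore import Tensor, ops

    def square(x):
        return x ** 2

    vmap_square = ops.vmap(square, in_axes=0, out_axes=0)
    print(vmap_square(Tensor(np.array([1., 2., 3.], np.float32))))  # [1. 4. 9.]
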
@@ -626,8 +681,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
 
     Args:
         name (str): Operator name.
-        read_value (bool): If the registered function not need to set value on Parameter,
+        read_value (bool, optional): If the registered function do not need to set value on Parameter,
             and all inputs will pass by value, set `read_value` to True. Default: False.
+        doc_url (str, optional): The official document link corresponding to the registered function. Default:"".
 
     Raises:
         ValueError: If failed to find a matching function for the given arguments.
@@ -641,10 +697,10 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
         >>> from mindspore import Tensor
         >>> from mindspore import ops
         >>> from mindspore import dtype as mstype
-        >>>
+        >>> import mindspore.ops as ops
         >>>
         >>> tensor_add = ops.Add()
-        >>> add = MultitypeFuncGraph('add')
+        >>> add = ops.MultitypeFuncGraph('add')
         >>> @add.register("Number", "Number")
         ... def add_scala(x, y):
         ...     return x + y
@@ -659,9 +715,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
         [0.2 1.2 2.4]
     """
 
-    def __init__(self, name, read_value=False):
+    def __init__(self, name, read_value=False, doc_url=""):
         """Initialize MultitypeFuncGraph."""
-        MultitypeFuncGraph_.__init__(self, name)
+        MultitypeFuncGraph_.__init__(self, name, doc_url)
         self.entries = list()
         if read_value:
             self.set_signatures((
@@ -726,10 +782,11 @@ class HyperMap(HyperMap_):
             Only supported in graph mode. Default is False.
 
     Inputs:
-        - **args** (Tuple[sequence]) -
-          And each row of the sequences will be the inputs of the operation.
+        - **args** (Tuple[sequence]) -
 
-          If `ops` is `None`, the
+          - If `ops` is not `None`, all the inputs should be sequences with the same length.
+            And each row of the sequences will be the inputs of the operation.
+          - If `ops` is `None`, the first input is the operation, and the others are inputs.
 
     Note:
         Except for the operation input, the number of inputs should be equal to the number of inputs to `ops`.
@@ -747,23 +804,22 @@ class HyperMap(HyperMap_):
 
     Examples:
         >>> from mindspore import Tensor, ops
-        >>> from mindspore.ops.composite.base import MultitypeFuncGraph, HyperMap
         >>> from mindspore import dtype as mstype
        >>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
        ...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
        >>> # square all the tensor in the nested list
        >>>
-        >>> square = MultitypeFuncGraph('square')
+        >>> square = ops.MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
-        >>> common_map = HyperMap()
+        >>> common_map = ops.HyperMap()
        >>> output = common_map(square, nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
-        >>> square_map = HyperMap(square, False)
+        >>> square_map = ops.HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
@@ -859,100 +915,6 @@ class Map(Map_):
         return tuple(map(func, *args_list))
 
 
-class Shard(Shard_):
-    """Shard operation"""
-
-    def __init__(self):
-        """Initialize Shard."""
-        Shard_.__init__(self, 'Shard')
-        self.shard_fn = None
-        self.fn = None
-        self.in_strategy = None
-        self.out_strategy = None
-        self.parameter_plan = None
-        self.device = None
-        self.level = None
-
-    def __call__(self, fn, in_strategy, out_strategy, parameter_plan=None, device="Ascend", level=0):
-        if context.get_context("mode") != context.PYNATIVE_MODE or \
-                context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
-            raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode")
-        if context.get_context("device_target") not in ["Ascend"]:
-            raise AssertionError(f"'Shard' now only supports 'Ascend'")
-        if context.get_auto_parallel_context("full_batch"):
-            raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False")
-        if context.get_auto_parallel_context("search_mode") != "sharding_propagation":
-            raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'")
-        if not isinstance(in_strategy, tuple):
-            raise TypeError(f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
-        if not isinstance(out_strategy, tuple):
-            raise TypeError(f"For 'Shard', the 'out_strategy' should be a tuple, "
-                            f"but got {type(out_strategy).__name__}")
-        if not isinstance(parameter_plan, (dict, type(None))):
-            raise TypeError(f"For 'Shard', the 'parameter_plan' should be a dict or None, "
-                            f"but got {type(parameter_plan).__name__}")
-        if isinstance(parameter_plan, dict):
-            for k in parameter_plan.keys():
-                v = parameter_plan[k]
-                if not isinstance(k, str) or not isinstance(v, tuple):
-                    raise TypeError(f"For 'Shard', the type of each key and value in 'parameter_plan' must be str and "
-                                    f"tuple, but got {type(k).__name__} and {type(parameter_plan[v]).__name__}")
-            parameter_plan = self._parameter_plan_dict2tuple(parameter_plan)
-
-        if not isinstance(device, str):
-            raise TypeError(f"For 'Shard', the 'device' should be a string, "
-                            f"but got {type(device).__name__}")
-        if not isinstance(level, int):
-            raise TypeError(f"For 'Shard', the 'level' should be an integer, "
-                            f"but got {type(level).__name__}")
-
-        if ms.get_algo_parameters("fully_use_devices") is True:
-            logger.warning("After calling 'shard', the environment variable 'fully_use_devices' "
-                           "will be overwritten as False")
-            ms.set_algo_parameters(fully_use_devices=False)
-
-        if self._is_attrs_has_been_set(fn, in_strategy, out_strategy, parameter_plan, device, level):
-            return self.shard_fn
-        shard_ = Shard()
-
-        if isinstance(fn, nn.Cell):
-            for param in fn.trainable_params():
-                param.is_in_shard = True
-
-        def shard_fn(*args):
-            args = (fn,) + args
-
-            @ms_function(hash_args=fn)
-            def after_shard(*args):
-                return shard_(fn, in_strategy, out_strategy, parameter_plan, device, level)(*args)
-
-            return after_shard(*args)
-
-        self.shard_fn = shard_fn
-        self.fn = fn
-        self.in_strategy = in_strategy
-        self.out_strategy = out_strategy
-        self.parameter_plan = parameter_plan
-        self.device = device
-        self.level = level
-        return self.shard_fn
-
-    @staticmethod
-    def _parameter_plan_dict2tuple(parameter_plan):
-        if not isinstance(parameter_plan, dict):
-            return parameter_plan
-
-        parameter_plan_tuple = ()
-        for k in parameter_plan:
-            parameter_plan_tuple += ((k, parameter_plan[k]),)
-        return parameter_plan_tuple
-
-    def _is_attrs_has_been_set(self, fn, in_strategy, out_strategy, parameter_plan, device, level):
-        return self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
-            self.out_strategy == out_strategy and self.parameter_plan == parameter_plan and \
-            self.device == device and self.level == level
-
-
 class _ListAppend(ListAppend_):
     """
     A metafuncgraph class that append one element to list.
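
Of the removed `Shard` plumbing, the dict-to-tuple conversion is the piece worth keeping in mind when reading strategy code elsewhere; the removed helper restated standalone (a sketch of the same behaviour, not the relocated class):

    def parameter_plan_dict2tuple(parameter_plan):
        # flattens {param_name: strategy} into ((param_name, strategy), ...)
        if not isinstance(parameter_plan, dict):
            return parameter_plan
        return tuple((k, v) for k, v in parameter_plan.items())

    print(parameter_plan_dict2tuple({"net.weight": (4, 1)}))  # (('net.weight', (4, 1)),)
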
@@ -1067,23 +1029,80 @@ class _ListExtend(ListExtend_):
 _extend = _ListExtend("extend")
 
 
-class
+class _DictClear(DictClear_):
     """
-    A metafuncgraph class that
+    A metafuncgraph class that clear the dict.
 
     Args:
         name (str): The name of the metafuncgraph object.
     """
 
     def __init__(self, name):
-        """Initialize
-
+        """Initialize _DictClear."""
+        DictClear_.__init__(self, name)
 
     def __call__(self, *args):
         pass
 
 
-
+_dict_clear = _DictClear("clear")
+
+
+class _DictHasKey(DictHasKey_):
+    """
+    A metafuncgraph class that Check if key is in dict.
+
+    Args:
+        name (str): The name of the metafuncgraph object.
+    """
+
+    def __init__(self, name):
+        """Initialize _DictHasKey."""
+        DictHasKey_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_haskey = _DictHasKey("has_key")
+
+
+class _DictUpdate(DictUpdate_):
+    """
+    A metafuncgraph class that append another dict to the end of the dict.
+
+    Args:
+        name (str): The name of the metafuncgraph object.
+    """
+
+    def __init__(self, name):
+        """Initialize _DictUpdate."""
+        DictUpdate_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_update = _DictUpdate("update")
+
+
+class _DictFromKeys(DictFromKeys_):
+    """
+    A metafuncgraph class that creates a new dict from the given sequence and value.
+
+    Args:
+        name (str): The name of the metafuncgraph object.
+    """
+
+    def __init__(self, name):
+        """Initialize _DictFromKeys."""
+        DictFromKeys_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_fromkeys = _DictFromKeys("fromkeys")
 
 
 class _Tail(Tail_):
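
The four new metafuncgraphs back ordinary dict methods inside graph-compiled code; their plain-Python equivalents, for orientation:

    d = {"a": 1}
    d.update({"b": 2})                   # what _DictUpdate provides in graph mode
    print("a" in d)                      # _DictHasKey ("has_key")
    print(dict.fromkeys(("x", "y"), 0))  # _DictFromKeys
    d.clear()                            # _DictClear
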
@@ -1118,15 +1137,3 @@ class _ZipOperation(ZipOperation_):
 
 zip_operation = _ZipOperation('zip_operation')
 """`zip_operation` will generate a tuple of zip iterations of inputs."""
-
-env_get = MultitypeFuncGraph("env_get")
-
-environ_get = Primitive('EnvironGet')
-ref_to_embed = _grad_ops.RefToEmbed()
-zeros_like = P.ZerosLike()
-
-
-@env_get.register("EnvType", "Tensor")
-def _tensor_env_get(env, parameter):
-    """Used to get env."""
-    return environ_get(env, ref_to_embed(parameter), zeros_like(parameter))