mindspore-2.4.0-cp311-cp311-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2978 @@
|
|
|
1
|
+
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""Operators for gradients."""
|
|
17
|
+
# pylint: disable=unused-import
|
|
18
|
+
from __future__ import absolute_import
|
|
19
|
+
|
|
20
|
+
from __future__ import division
|
|
21
|
+
from mindspore._checkparam import _check_3d_int_or_tuple
|
|
22
|
+
from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple
|
|
23
|
+
from mindspore.ops import signature as sig
|
|
24
|
+
from mindspore.ops._utils import get_concat_offset
|
|
25
|
+
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
|
26
|
+
import mindspore.context as context
|
|
27
|
+
from mindspore import _checkparam as validator
|
|
28
|
+
from mindspore.common import dtype as mstype
|
|
29
|
+
from mindspore.communication.management import GlobalComm
|
|
30
|
+
from mindspore.common._utils import is_shape_unknown, is_dim_unknown
|
|
31
|
+
from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad, AsinhGrad, ReciprocalGrad, RsqrtGrad,
|
|
32
|
+
SqrtGrad, BatchNormGrad, BatchNormGradGrad, BiasAddGrad, GeLUGrad, FastGeLUGrad,
|
|
33
|
+
AvgPoolGrad, MinimumGrad, LogSoftmaxGrad, PReLUGrad, ReluGrad, ReLU6Grad, EluGrad,
|
|
34
|
+
GatherDGradV2, ResizeBilinearGrad, ResizeLinear1DGrad, ResizeNearestNeighborV2Grad,
|
|
35
|
+
SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
|
|
36
|
+
ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
|
|
37
|
+
HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
|
|
38
|
+
FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
|
|
39
|
+
BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SparseFillEmptyRowsGrad(Primitive):
|
|
43
|
+
"""Performs grad of SparseFillEmptyRows operation."""
|
|
44
|
+
|
|
45
|
+
@prim_attr_register
|
|
46
|
+
def __init__(self):
|
|
47
|
+
"""Initialize SparseFillEmptyRowsGrad."""
|
|
48
|
+
self.init_prim_io_names(inputs=['reverse_index_map', 'grad_values'],
|
|
49
|
+
outputs=['y_values', 'y_default_value'])
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ScaleAndTranslateGrad(Primitive):
|
|
53
|
+
"""Performs grad of ScaleAndTranslate operation."""
|
|
54
|
+
|
|
55
|
+
@prim_attr_register
|
|
56
|
+
def __init__(self, kernel_type="lanczos3", antialias=True):
|
|
57
|
+
"""Initialize ScaleAndTranslateGrad"""
|
|
58
|
+
validator.check_value_type("kernel_type", kernel_type, [str], self.name)
|
|
59
|
+
validator.check_string(kernel_type, ["lanczos1", "lanczos3", "lanczos5", "gaussian", "box", "triangle",
|
|
60
|
+
"keyscubic", "mitchellcubic"], "kernel_type", self.name)
|
|
61
|
+
validator.check_value_type("antialias", antialias, [bool], self.name)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SoftmaxGrad(Primitive):
|
|
65
|
+
"""Performs grad of Softmax operation."""
|
|
66
|
+
|
|
67
|
+
@prim_attr_register
|
|
68
|
+
def __init__(self):
|
|
69
|
+
"""Initialize SoftmaxGrad"""
|
|
70
|
+
self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class SyncBatchNormGrad(Primitive):
|
|
74
|
+
"""Performs grad of SyncBatchNorm operation."""
|
|
75
|
+
|
|
76
|
+
@prim_attr_register
|
|
77
|
+
def __init__(self, epsilon=1e-5, group="group0", device_num=2):
|
|
78
|
+
validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
|
|
79
|
+
if not isinstance(group, str):
|
|
80
|
+
raise TypeError("The group attr of SyncBatchNormGrad must be str.")
|
|
81
|
+
validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class KLDivLossGrad(Primitive):
|
|
85
|
+
"""Computes gradients for `KLDivLoss` operation."""
|
|
86
|
+
|
|
87
|
+
@prim_attr_register
|
|
88
|
+
def __init__(self, reduction='mean'):
|
|
89
|
+
device_target = context.get_context("device_target")
|
|
90
|
+
if device_target == "CPU":
|
|
91
|
+
support_mode = ['none', 'mean', 'batchmean', 'sum']
|
|
92
|
+
elif device_target == "GPU":
|
|
93
|
+
support_mode = ['none', 'mean', 'sum']
|
|
94
|
+
elif device_target == "Ascend":
|
|
95
|
+
support_mode = ['none', 'mean', 'batchmean', 'sum']
|
|
96
|
+
else:
|
|
97
|
+
raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
|
|
98
|
+
self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class LuUnpackGrad(Primitive):
|
|
102
|
+
"""Computes gradients for `LuUnpack` operation."""
|
|
103
|
+
|
|
104
|
+
@prim_attr_register
|
|
105
|
+
def __init__(self, L_grad_flag, U_grad_flag):
|
|
106
|
+
validator.check_value_type("L_grad_flag", L_grad_flag, [bool], self.name)
|
|
107
|
+
validator.check_value_type("U_grad_flag", U_grad_flag, [bool], self.name)
|
|
108
|
+
self.add_prim_attr("cust_aicpu", self.name)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class ConcatOffset(PrimitiveWithInfer):
|
|
112
|
+
"""primitive for computing Concat's gradient."""
|
|
113
|
+
|
|
114
|
+
@prim_attr_register
|
|
115
|
+
def __init__(self, N=2, axis=0):
|
|
116
|
+
"""Initialize ConcatOffset"""
|
|
117
|
+
|
|
118
|
+
def __infer__(self, input_x):
|
|
119
|
+
axis = self.axis
|
|
120
|
+
x_shp = input_x['shape']
|
|
121
|
+
x_type = input_x['dtype']
|
|
122
|
+
self.add_prim_attr('T', x_type[0].element_type())
|
|
123
|
+
|
|
124
|
+
# input_x is dynamic rank
|
|
125
|
+
rank = -1
|
|
126
|
+
is_dyn_rank = False
|
|
127
|
+
for _, sh in enumerate(x_shp):
|
|
128
|
+
if is_dim_unknown(sh):
|
|
129
|
+
is_dyn_rank = True
|
|
130
|
+
else:
|
|
131
|
+
rank = len(sh)
|
|
132
|
+
if is_dyn_rank:
|
|
133
|
+
return {
|
|
134
|
+
'shape': [len(x_shp), rank],
|
|
135
|
+
'dtype': mstype.int64,
|
|
136
|
+
'value': None
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
# if the dimension of input_x on the axis is dynamic
|
|
140
|
+
if axis < -rank or axis >= rank:
|
|
141
|
+
raise ValueError("For 'ConcatOffset', 'axis' must be in range [{}, {}), but got {}"
|
|
142
|
+
.format(-rank, rank, axis))
|
|
143
|
+
if axis < 0:
|
|
144
|
+
axis = axis + rank
|
|
145
|
+
for each in x_shp:
|
|
146
|
+
if each[axis] == -1:
|
|
147
|
+
return {
|
|
148
|
+
'shape': [len(x_shp), len(x_shp[0])],
|
|
149
|
+
'dtype': mstype.int64,
|
|
150
|
+
'value': None
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
|
|
154
|
+
offset_values = []
|
|
155
|
+
for i in range(len(x_shp)):
|
|
156
|
+
values = []
|
|
157
|
+
for j in range(len(x_shp[0])):
|
|
158
|
+
value = 0
|
|
159
|
+
if j == axis:
|
|
160
|
+
value = offset[i]
|
|
161
|
+
values.append(value)
|
|
162
|
+
offset_values.append(tuple(values))
|
|
163
|
+
out = {'shape': None,
|
|
164
|
+
'dtype': None,
|
|
165
|
+
'value': tuple(offset_values)}
|
|
166
|
+
return out
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class Conv3DBackpropFilter(Primitive):
|
|
170
|
+
"""
|
|
171
|
+
Computes the gradients of convolution 3D with respect to the filter.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
out_channel (int): The dimension of the output.
|
|
175
|
+
kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
|
|
176
|
+
mode (int): Modes for different convolutions. Not currently used.
|
|
177
|
+
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
|
|
178
|
+
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
|
|
179
|
+
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four
|
|
180
|
+
integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
|
|
181
|
+
pad[3], pad[4] and pad[5] correspondingly.
|
|
182
|
+
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
|
|
183
|
+
dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
|
|
184
|
+
group (int): Splits input into groups. Default: 1.
|
|
185
|
+
data_format (str): The optional value for data format. Currently only support 'NCDHW'.
|
|
186
|
+
|
|
187
|
+
Inputs:
|
|
188
|
+
- **x** (Tensor) - The input of the convolution, then the shape is :math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`.
|
|
189
|
+
Currently dout data type only support float16 and float32.
|
|
190
|
+
- **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default
|
|
191
|
+
data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Currently dout data type only support float16
|
|
192
|
+
and float32.
|
|
193
|
+
- **w_size** (tuple(int)) - A tuple describes the shape of the weight which conforms to the format
|
|
194
|
+
:math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
|
|
195
|
+
|
|
196
|
+
Outputs:
|
|
197
|
+
Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.
|
|
198
|
+
|
|
199
|
+
Supported Platforms:
|
|
200
|
+
``Ascend``
|
|
201
|
+
|
|
202
|
+
Examples:
|
|
203
|
+
>>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16)
|
|
204
|
+
>>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
|
|
205
|
+
>>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
|
|
206
|
+
>>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2))
|
|
207
|
+
>>> output = conv3d_backprop_input(x, dout, F.shape(w))
|
|
208
|
+
>>> print(output.shape)
|
|
209
|
+
(32, 32, 4, 6, 2)
|
|
210
|
+
"""
|
|
211
|
+
|
|
212
|
+
@prim_attr_register
|
|
213
|
+
def __init__(self,
|
|
214
|
+
out_channel,
|
|
215
|
+
kernel_size,
|
|
216
|
+
mode=1,
|
|
217
|
+
pad_mode="valid",
|
|
218
|
+
pad=0,
|
|
219
|
+
stride=(1, 1, 1, 1, 1),
|
|
220
|
+
dilation=(1, 1, 1, 1, 1),
|
|
221
|
+
group=1,
|
|
222
|
+
data_format="NCDHW"):
|
|
223
|
+
"""Initialize Convolution"""
|
|
224
|
+
self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y'])
|
|
225
|
+
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
|
226
|
+
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
|
|
227
|
+
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
|
|
228
|
+
self.add_prim_attr('strides', self.stride)
|
|
229
|
+
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
|
|
230
|
+
self.add_prim_attr('dilations', self.dilation)
|
|
231
|
+
validator.check_value_type('pad', pad, (int, tuple), self.name)
|
|
232
|
+
if isinstance(pad, int):
|
|
233
|
+
pad = (pad,) * 6
|
|
234
|
+
validator.check_equal_int(len(pad), 6, 'pad size', self.name)
|
|
235
|
+
self.add_prim_attr('pad', pad)
|
|
236
|
+
self.pad_list = pad
|
|
237
|
+
self.add_prim_attr('pad_list', self.pad_list)
|
|
238
|
+
|
|
239
|
+
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
|
240
|
+
self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
|
241
|
+
if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
|
|
242
|
+
raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode must be set as 'pad'.")
|
|
243
|
+
if self.pad_mode == 'pad':
|
|
244
|
+
for item in pad:
|
|
245
|
+
validator.check_non_negative_int(item, 'pad item', self.name)
|
|
246
|
+
self.add_prim_attr('pad_mode', self.pad_mode)
|
|
247
|
+
|
|
248
|
+
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
|
|
249
|
+
self.add_prim_attr('mode', self.mode)
|
|
250
|
+
self.group = validator.check_positive_int(group, 'group', self.name)
|
|
251
|
+
self.add_prim_attr('groups', self.group)
|
|
252
|
+
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
|
|
253
|
+
self.add_prim_attr('data_format', self.format)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class Conv2DBackpropFilter(Primitive):
|
|
257
|
+
"""
|
|
258
|
+
Computes the gradients of convolution with respect to the filter.
|
|
259
|
+
|
|
260
|
+
Args:
|
|
261
|
+
out_channel (int): The dimensionality of the output space.
|
|
262
|
+
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
|
|
263
|
+
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
|
|
264
|
+
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
|
|
265
|
+
top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
|
|
266
|
+
padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
|
|
267
|
+
pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
|
|
268
|
+
mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution ,
|
|
269
|
+
2 deconvolution, 3 depthwise convolution. Default: 1.
|
|
270
|
+
stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
|
|
271
|
+
dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
|
|
272
|
+
group (int): Splits input into groups. Default: 1.
|
|
273
|
+
data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW', \
|
|
274
|
+
default is 'NCHW'.
|
|
275
|
+
|
|
276
|
+
Returns:
|
|
277
|
+
Tensor, the gradients of convolution.
|
|
278
|
+
"""
|
|
279
|
+
|
|
280
|
+
@prim_attr_register
|
|
281
|
+
def __init__(self,
|
|
282
|
+
out_channel,
|
|
283
|
+
kernel_size,
|
|
284
|
+
pad_mode="valid",
|
|
285
|
+
pad=0,
|
|
286
|
+
pad_list=(0, 0, 0, 0),
|
|
287
|
+
mode=1,
|
|
288
|
+
stride=(1, 1),
|
|
289
|
+
dilation=(1, 1, 1, 1),
|
|
290
|
+
group=1,
|
|
291
|
+
data_format="NCHW"):
|
|
292
|
+
"""Initialize Convolution"""
|
|
293
|
+
self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
|
|
294
|
+
self.out_channel = out_channel
|
|
295
|
+
self.kernel_size = kernel_size
|
|
296
|
+
self.mode = mode
|
|
297
|
+
pad_mode = pad_mode.upper()
|
|
298
|
+
self.add_prim_attr('pad_mode', pad_mode)
|
|
299
|
+
if isinstance(pad, int):
|
|
300
|
+
pad = (pad,) * 4
|
|
301
|
+
else:
|
|
302
|
+
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
|
|
303
|
+
self.add_prim_attr("pad", pad)
|
|
304
|
+
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
305
|
+
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
306
|
+
raise ValueError("NHWC format only support in GPU target.")
|
|
307
|
+
self.add_prim_attr('data_format', self.format)
|
|
308
|
+
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
|
|
309
|
+
self.add_prim_attr('stride', self.stride)
|
|
310
|
+
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
|
311
|
+
self.add_prim_attr('dilation', self.dilation)
|
|
312
|
+
self.group = group
|
|
313
|
+
self.add_prim_attr('groups', group)
|
|
314
|
+
if pad_list:
|
|
315
|
+
for x in pad_list:
|
|
316
|
+
if x != -1:
|
|
317
|
+
validator.check_non_negative_int(x, 'element of pad_list', self.name)
|
|
318
|
+
self.pad_list = pad_list
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
|
|
322
|
+
"""
|
|
323
|
+
Returns the gradient of filter for DepthwiseConv2dNative.
|
|
324
|
+
|
|
325
|
+
Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.
|
|
326
|
+
|
|
327
|
+
Refer to class DepthwiseConv2dNative for more details.
|
|
328
|
+
|
|
329
|
+
Args:
|
|
330
|
+
channel_multiplier (int): The multiplier for the original output conv.
|
|
331
|
+
kernel_size (int or tuple): The size of the conv kernel.
|
|
332
|
+
mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
|
|
333
|
+
2 deconvolution,3 depthwise convolution. Default: 3.
|
|
334
|
+
pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
|
|
335
|
+
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
|
|
336
|
+
top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
|
|
337
|
+
padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
|
|
338
|
+
pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
|
|
339
|
+
stride (int): The stride to be applied to the convolution filter. Default: 1.
|
|
340
|
+
dilation (int): Specifies the space to use between kernel elements. Default: 1.
|
|
341
|
+
group (int): Splits input into groups. Default: 1.
|
|
342
|
+
|
|
343
|
+
Returns:
|
|
344
|
+
Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
|
|
345
|
+
"""
|
|
346
|
+
|
|
347
|
+
@prim_attr_register
|
|
348
|
+
def __init__(self,
|
|
349
|
+
channel_multiplier,
|
|
350
|
+
kernel_size,
|
|
351
|
+
pad_mode="valid",
|
|
352
|
+
pad=0,
|
|
353
|
+
pad_list=(0, 0, 0, 0),
|
|
354
|
+
mode=3,
|
|
355
|
+
stride=1,
|
|
356
|
+
dilation=1,
|
|
357
|
+
group=1):
|
|
358
|
+
"""Initialize Convolution"""
|
|
359
|
+
self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
|
|
360
|
+
self.channel_multiplier = channel_multiplier
|
|
361
|
+
self.kernel_size = kernel_size
|
|
362
|
+
self.mode = mode
|
|
363
|
+
self.pad_mode = pad_mode
|
|
364
|
+
if isinstance(pad, int):
|
|
365
|
+
pad = (pad,) * 4
|
|
366
|
+
else:
|
|
367
|
+
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
|
|
368
|
+
self.add_prim_attr("pad", pad)
|
|
369
|
+
self.pad_list = pad_list
|
|
370
|
+
self.stride = stride
|
|
371
|
+
self.dilation = dilation
|
|
372
|
+
self.group = group
|
|
373
|
+
self.add_prim_attr('data_format', "NCHW")
|
|
374
|
+
|
|
375
|
+
def __infer__(self, x, w_size, dout):
|
|
376
|
+
w_size_v = w_size['value']
|
|
377
|
+
args = {'x': x['dtype'], 'dout': dout['dtype']}
|
|
378
|
+
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
|
379
|
+
out = {
|
|
380
|
+
'value': None,
|
|
381
|
+
'shape': w_size_v,
|
|
382
|
+
'dtype': dout['dtype'],
|
|
383
|
+
}
|
|
384
|
+
return out
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
|
|
388
|
+
"""
|
|
389
|
+
Returns the gradient of input for DepthwiseConv2dNative.
|
|
390
|
+
|
|
391
|
+
Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.
|
|
392
|
+
|
|
393
|
+
Args:
|
|
394
|
+
channel_multiplier (int): The multiplier for the original output conv.
|
|
395
|
+
kernel_size (int or tuple): The size of the conv kernel.
|
|
396
|
+
mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution ,
|
|
397
|
+
2 deconvolution,3 depthwise convolution. Default: 3.
|
|
398
|
+
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
|
|
399
|
+
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
|
|
400
|
+
top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
|
|
401
|
+
padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
|
|
402
|
+
pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
|
|
403
|
+
stride (int): The stride to be applied to the convolution filter. Default: 1.
|
|
404
|
+
dilation (int): Specifies the space to use between kernel elements. Default: 1.
|
|
405
|
+
group (int): Splits input into groups. Default: 1.
|
|
406
|
+
|
|
407
|
+
Returns:
|
|
408
|
+
Tensor, the value is the gradient of input for DepthwiseConv2dNative.
|
|
409
|
+
"""
|
|
410
|
+
|
|
411
|
+
@prim_attr_register
|
|
412
|
+
def __init__(self,
|
|
413
|
+
channel_multiplier,
|
|
414
|
+
kernel_size,
|
|
415
|
+
pad_mode="valid",
|
|
416
|
+
pad=0,
|
|
417
|
+
pad_list=(0, 0, 0, 0),
|
|
418
|
+
mode=3,
|
|
419
|
+
stride=1,
|
|
420
|
+
dilation=1,
|
|
421
|
+
group=1):
|
|
422
|
+
"""Initialize Convolution"""
|
|
423
|
+
self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
|
|
424
|
+
self.channel_multiplier = channel_multiplier
|
|
425
|
+
self.kernel_size = kernel_size
|
|
426
|
+
self.mode = mode
|
|
427
|
+
self.pad_mode = pad_mode
|
|
428
|
+
if isinstance(pad, int):
|
|
429
|
+
pad = (pad,) * 4
|
|
430
|
+
else:
|
|
431
|
+
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
|
|
432
|
+
self.add_prim_attr("pad", pad)
|
|
433
|
+
self.pad_list = pad_list
|
|
434
|
+
self.stride = stride
|
|
435
|
+
self.dilation = dilation
|
|
436
|
+
self.group = group
|
|
437
|
+
self.add_prim_attr('data_format', "NCHW")
|
|
438
|
+
|
|
439
|
+
def __infer__(self, x_size, w, dout):
|
|
440
|
+
args = {'w': w['dtype'], 'dout': dout['dtype']}
|
|
441
|
+
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
|
442
|
+
x_size_v = x_size['value']
|
|
443
|
+
out = {
|
|
444
|
+
'value': None,
|
|
445
|
+
'shape': x_size_v,
|
|
446
|
+
'dtype': dout['dtype'],
|
|
447
|
+
}
|
|
448
|
+
return out
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
class DropoutGrad(Primitive):
|
|
452
|
+
"""
|
|
453
|
+
The gradient of Dropout. During training, randomly zeroes some of the elements
|
|
454
|
+
of the input tensor with probability.
|
|
455
|
+
|
|
456
|
+
Args:
|
|
457
|
+
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
|
|
458
|
+
means dropping out 10% of input units. Default: 0.5.
|
|
459
|
+
|
|
460
|
+
Inputs:
|
|
461
|
+
- **shape** (tuple[int]) - The shape of target mask.
|
|
462
|
+
|
|
463
|
+
Outputs:
|
|
464
|
+
Tensor, the value of generated mask for input shape.
|
|
465
|
+
|
|
466
|
+
Examples:
|
|
467
|
+
>>> dropout_grad = ops.DropoutGrad(keep_prob=0.5)
|
|
468
|
+
>>> in = Tensor((20, 16, 50, 50))
|
|
469
|
+
>>> out = dropout_grad(in)
|
|
470
|
+
"""
|
|
471
|
+
|
|
472
|
+
@prim_attr_register
|
|
473
|
+
def __init__(self, keep_prob=0.5):
|
|
474
|
+
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, validator.INC_RIGHT, "keep_prob", self.name)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
class FlattenGrad(PrimitiveWithInfer):
|
|
478
|
+
"""Performs gradients of Flatten."""
|
|
479
|
+
|
|
480
|
+
@prim_attr_register
|
|
481
|
+
def __init__(self):
|
|
482
|
+
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])
|
|
483
|
+
|
|
484
|
+
def __infer__(self, *args):
|
|
485
|
+
out = {
|
|
486
|
+
'value': None,
|
|
487
|
+
'shape': args[1]['value'],
|
|
488
|
+
'dtype': args[0]['dtype'],
|
|
489
|
+
}
|
|
490
|
+
return out
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
class InstanceNormGrad(PrimitiveWithInfer):
|
|
494
|
+
"""Gradients of InstanceNorm operation."""
|
|
495
|
+
|
|
496
|
+
@prim_attr_register
|
|
497
|
+
def __init__(self, epsilon=0.0, momentum=0.1):
|
|
498
|
+
self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
|
|
499
|
+
outputs=['dx', 'bn_gamma', 'bn_beta'])
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
class InstanceNormV2Grad(Primitive):
|
|
503
|
+
"""Gradients of InstanceNormV2 operation."""
|
|
504
|
+
|
|
505
|
+
@prim_attr_register
|
|
506
|
+
def __init__(self, is_training=True, epsilon=1e-5):
|
|
507
|
+
self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'mean', 'variance', 'save_mean', 'save_variance'],
|
|
508
|
+
outputs=['pd_x', 'pd_gamma', 'pd_beta'])
|
|
509
|
+
validator.check_is_float(epsilon, 'epsilon', self.name)
|
|
510
|
+
validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
|
|
511
|
+
validator.check_bool(is_training, "is_training", self.name)
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
class EinsumGrad(PrimitiveWithInfer):
|
|
515
|
+
"""Gradients of Einsum."""
|
|
516
|
+
|
|
517
|
+
@prim_attr_register
|
|
518
|
+
def __init__(self, equation):
|
|
519
|
+
pass
|
|
520
|
+
|
|
521
|
+
def infer_shape(self, x_shapes, dout_shape):
|
|
522
|
+
out_shape = ()
|
|
523
|
+
for dim in x_shapes:
|
|
524
|
+
out_shape += (dim,)
|
|
525
|
+
return out_shape
|
|
526
|
+
|
|
527
|
+
def infer_dtype(self, x_types, dout_shape):
|
|
528
|
+
out_type = ()
|
|
529
|
+
for cur_type in x_types:
|
|
530
|
+
out_type += (cur_type,)
|
|
531
|
+
return out_type
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
class UniqueGrad(Primitive):
|
|
535
|
+
"""Gradients of Unique operation."""
|
|
536
|
+
|
|
537
|
+
@prim_attr_register
|
|
538
|
+
def __init__(self):
|
|
539
|
+
self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
class BNTrainingReduceGrad(Primitive):
|
|
543
|
+
"""Gradients of FusedBatchNorm operation."""
|
|
544
|
+
|
|
545
|
+
@prim_attr_register
|
|
546
|
+
def __init__(self, epsilon=0.0001, data_format='NCHW'):
|
|
547
|
+
self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
548
|
+
_inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
|
|
549
|
+
self.init_prim_io_names(inputs=_inputs, outputs=['y'])
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
class BNTrainingUpdateGrad(Primitive):
|
|
553
|
+
"""Gradients of FusedBatchNorm operation."""
|
|
554
|
+
|
|
555
|
+
@prim_attr_register
|
|
556
|
+
def __init__(self, epsilon=0.0001, data_format='NCHW'):
|
|
557
|
+
self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
558
|
+
self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
|
|
559
|
+
outputs=['diff_scale', 'diff_offset'])
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
class NeighborExchangeV2Grad(PrimitiveWithInfer):
|
|
563
|
+
""""Gradients of NeighborExchangeV2 operation."""
|
|
564
|
+
|
|
565
|
+
@prim_attr_register
|
|
566
|
+
def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
|
|
567
|
+
group=GlobalComm.WORLD_COMM_GROUP):
|
|
568
|
+
self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
|
|
569
|
+
self.send_rank_ids = send_rank_ids
|
|
570
|
+
self.recv_rank_ids = recv_rank_ids
|
|
571
|
+
self.send_lens = send_lens
|
|
572
|
+
self.recv_lens = recv_lens
|
|
573
|
+
self.format = validator.check_string(data_format, ['NCHW'], 'format', self.name)
|
|
574
|
+
self.add_prim_attr('no_elimilate', True)
|
|
575
|
+
|
|
576
|
+
def __infer__(self, dy):
|
|
577
|
+
dy_shape = dy['shape']
|
|
578
|
+
validator.check(f'dy_shape.size()', len(dy_shape), f'4', 4, validator.EQ, self.name)
|
|
579
|
+
if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1:
|
|
580
|
+
dy_shape[3] -= self.send_lens[2]
|
|
581
|
+
|
|
582
|
+
if self.send_rank_ids[1] != -1 or self.send_rank_ids[2] != -1 or self.send_rank_ids[3] != -1:
|
|
583
|
+
dy_shape[3] -= self.send_lens[3]
|
|
584
|
+
|
|
585
|
+
if self.send_rank_ids[0] != -1 or self.send_rank_ids[1] != -1 or self.send_rank_ids[7] != -1:
|
|
586
|
+
dy_shape[2] -= self.send_lens[0]
|
|
587
|
+
|
|
588
|
+
if self.send_rank_ids[3] != -1 or self.send_rank_ids[4] != -1 or self.send_rank_ids[5] != -1:
|
|
589
|
+
dy_shape[2] -= self.send_lens[1]
|
|
590
|
+
|
|
591
|
+
return {'shape': dy_shape,
|
|
592
|
+
'dtype': dy['dtype'],
|
|
593
|
+
'value': None}
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class _PoolGrad(PrimitiveWithInfer):
|
|
597
|
+
"""Gradients of the max/avg pool operation."""
|
|
598
|
+
|
|
599
|
+
@prim_attr_register
|
|
600
|
+
def __init__(self, kernel_size, strides, pad_mode="VALID", data_format="NCHW"):
|
|
601
|
+
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
|
|
602
|
+
|
|
603
|
+
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
|
604
|
+
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
605
|
+
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
|
606
|
+
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
|
607
|
+
self.add_prim_attr("pad_mode", self.pad_mode)
|
|
608
|
+
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
609
|
+
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
610
|
+
raise ValueError("NHWC format only support in GPU target.")
|
|
611
|
+
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
|
|
612
|
+
if not self.is_maxpoolgradwithargmax:
|
|
613
|
+
self.add_prim_attr('data_format', self.format)
|
|
614
|
+
|
|
615
|
+
def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
|
|
616
|
+
validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
|
|
617
|
+
error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be an positive int number "
|
|
618
|
+
f"or a tuple of two or four positive int numbers, but got {arg_val}")
|
|
619
|
+
if isinstance(arg_val, int):
|
|
620
|
+
ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
|
|
621
|
+
elif len(arg_val) == 2:
|
|
622
|
+
ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
|
|
623
|
+
elif len(arg_val) == 4:
|
|
624
|
+
ret = arg_val
|
|
625
|
+
else:
|
|
626
|
+
raise error_msg
|
|
627
|
+
# whether all elements of tuple are positive integers
|
|
628
|
+
for item in ret:
|
|
629
|
+
if not isinstance(item, int) or item <= 0:
|
|
630
|
+
raise error_msg
|
|
631
|
+
return ret
|
|
632
|
+
|
|
633
|
+
kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size, self.is_maxpoolgradwithargmax)
|
|
634
|
+
strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
|
|
635
|
+
if self.format == "NCHW":
|
|
636
|
+
self.kernel_size = kernel_size
|
|
637
|
+
self.strides = strides
|
|
638
|
+
else:
|
|
639
|
+
self.kernel_size = [kernel_size[0], kernel_size[2], kernel_size[3], kernel_size[1]]
|
|
640
|
+
self.strides = [strides[0], strides[2], strides[3], strides[1]]
|
|
641
|
+
self.add_prim_attr("kernel_size", self.kernel_size)
|
|
642
|
+
self.add_prim_attr("strides", self.strides)
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
class AvgPoolGradVm(_PoolGrad):
|
|
646
|
+
"""Gradients of the avg pool operation for vm."""
|
|
647
|
+
|
|
648
|
+
@prim_attr_register
|
|
649
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
|
|
650
|
+
super(AvgPoolGradVm, self).__init__(kernel_size, strides, pad_mode)
|
|
651
|
+
self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])
|
|
652
|
+
|
|
653
|
+
def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
|
|
654
|
+
out = {
|
|
655
|
+
'value': None,
|
|
656
|
+
'shape': tuple(origin_input['value']),
|
|
657
|
+
'dtype': dout['dtype'],
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
return out
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
class AvgPoolGradGe(_PoolGrad):
|
|
664
|
+
"""Gradients of the avg pool operation for ge."""
|
|
665
|
+
|
|
666
|
+
@prim_attr_register
|
|
667
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
|
|
668
|
+
super(AvgPoolGradGe, self).__init__(kernel_size, strides, pad_mode, data_format)
|
|
669
|
+
|
|
670
|
+
def __infer__(self, origin_input, dout):
|
|
671
|
+
out = {
|
|
672
|
+
'value': None,
|
|
673
|
+
'shape': tuple(origin_input['value']),
|
|
674
|
+
'dtype': dout['dtype'],
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
return out
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
class AvgPoolGradV1(Primitive):
|
|
681
|
+
"""Gradients of the AvgPoolV1 operation."""
|
|
682
|
+
|
|
683
|
+
@prim_attr_register
|
|
684
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
|
|
685
|
+
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
|
686
|
+
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
687
|
+
self.pad_mode = validator.check_string(
|
|
688
|
+
pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
|
689
|
+
self.add_prim_attr("pad_mode", self.pad_mode)
|
|
690
|
+
self.format = validator.check_string(
|
|
691
|
+
data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
692
|
+
self.add_prim_attr('data_format', self.format)
|
|
693
|
+
|
|
694
|
+
def _avgpoolgrad_check_int_or_tuple(argname, argval):
|
|
695
|
+
validator.check_value_type(argname, argval, (int, tuple), self.name)
|
|
696
|
+
errormsg = ValueError(f"For '{self.name}' the '{argname}' should be an positive int number "
|
|
697
|
+
f"or a tuple of two or four positive int numbers, but got {argval}")
|
|
698
|
+
if isinstance(argval, int):
|
|
699
|
+
ret = (1, 1, argval, argval)
|
|
700
|
+
elif len(argval) == 2:
|
|
701
|
+
ret = (1, 1, argval[0], argval[1])
|
|
702
|
+
elif len(argval) == 4:
|
|
703
|
+
ret = argval
|
|
704
|
+
else:
|
|
705
|
+
raise errormsg
|
|
706
|
+
# whether all elements of tuple are positive integers?
|
|
707
|
+
for it in ret:
|
|
708
|
+
if not isinstance(it, int) or it <= 0:
|
|
709
|
+
raise errormsg
|
|
710
|
+
return ret
|
|
711
|
+
|
|
712
|
+
self.kernel_size = _avgpoolgrad_check_int_or_tuple(
|
|
713
|
+
"kernel_size", kernel_size)
|
|
714
|
+
self.strides = _avgpoolgrad_check_int_or_tuple("strides", strides)
|
|
715
|
+
|
|
716
|
+
self.kernel_size_adapt = self.kernel_size if self.format == "NCHW" else (
|
|
717
|
+
self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
|
|
718
|
+
self.strides_adapt = self.strides if self.format == "NCHW" else (
|
|
719
|
+
self.strides[0], self.strides[2], self.strides[3], self.strides[1])
|
|
720
|
+
|
|
721
|
+
# If length of some attrs is 4 we regard it as legal, either by using the op directly,
|
|
722
|
+
# or passed from an instance of forward op AvgPoolV1.
|
|
723
|
+
if len(self.kernel_size) == 4:
|
|
724
|
+
self.kernel_size_adapt = self.kernel_size
|
|
725
|
+
if len(self.strides) == 4:
|
|
726
|
+
self.strides_adapt = self.strides
|
|
727
|
+
|
|
728
|
+
self.add_prim_attr("kernel_size", self.kernel_size_adapt)
|
|
729
|
+
self.add_prim_attr("strides", self.strides_adapt)
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
class AdaptiveAvgPool2DGrad(Primitive):
|
|
733
|
+
"""Gradients of the adaptive avg pool 2D operation."""
|
|
734
|
+
|
|
735
|
+
@prim_attr_register
|
|
736
|
+
def __init__(self):
|
|
737
|
+
"""Initialize AdaptiveAvgPool2DGrad"""
|
|
738
|
+
self.init_prim_io_names(inputs=['input_grad', 'orig_input_shape'], outputs=['output_grad'])
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
class AdaptiveAvgPool3DGrad(Primitive):
|
|
742
|
+
"""Performs grad of AdaptiveAvgPool3D operation."""
|
|
743
|
+
@prim_attr_register
|
|
744
|
+
def __init__(self):
|
|
745
|
+
self.init_prim_io_names(inputs=['y_grad', 'orig_input_shape'], outputs=['x_grad'])
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
class AvgPool3DGrad(Primitive):
|
|
749
|
+
"""Gradients of the avg pool3d operation."""
|
|
750
|
+
|
|
751
|
+
@prim_attr_register
|
|
752
|
+
def __init__(self, kernel_size=1, strides=1, pads=0, ceil_mode=False,
|
|
753
|
+
count_include_pad=True, divisor_override=0, data_format="NCDHW", pad_mode="pad"):
|
|
754
|
+
self.init_prim_io_names(inputs=['origin_input_shape', 'grads'], outputs=['output'])
|
|
755
|
+
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name, allow_five=True, ret_five=True)
|
|
756
|
+
self.add_prim_attr('kernel_size', self.kernel_size)
|
|
757
|
+
self.strides = _check_3d_int_or_tuple('strides', strides, self.name, allow_five=True, ret_five=True)
|
|
758
|
+
self.add_prim_attr('strides', self.strides)
|
|
759
|
+
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
|
760
|
+
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name)
|
|
761
|
+
validator.check_value_type('pads', pads, (int, tuple), self.name)
|
|
762
|
+
if isinstance(pads, int):
|
|
763
|
+
pads = (pads,) * 6
|
|
764
|
+
validator.check_equal_int(len(pads), 6, 'pad size', self.name)
|
|
765
|
+
for item in pads:
|
|
766
|
+
validator.check_non_negative_int(item, 'pad item', self.name)
|
|
767
|
+
self.add_prim_attr('pad_list', pads)
|
|
768
|
+
self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
|
|
769
|
+
self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
|
|
770
|
+
self.divisor_override = validator.check_value_type('divisor_override', divisor_override, int, self.name)
|
|
771
|
+
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
class AdaptiveMaxPool2DGrad(Primitive):
|
|
775
|
+
"""Gradients of the adaptive max pool 2D operation."""
|
|
776
|
+
@prim_attr_register
|
|
777
|
+
def __init__(self):
|
|
778
|
+
"""Initialize AdaptiveMaxPool2DGrad"""
|
|
779
|
+
self.init_prim_io_names(inputs=['y_grad', 'x', 'argmax'], outputs=['x_grad'])
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
class MaxPoolGrad(_PoolGrad):
|
|
783
|
+
"""Performs gradients of the max pool operation."""
|
|
784
|
+
|
|
785
|
+
@prim_attr_register
|
|
786
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
|
|
787
|
+
super(MaxPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)
|
|
788
|
+
|
|
789
|
+
def infer_shape(self, x1_shape, x2_shape, grad_shape):
|
|
790
|
+
return x1_shape
|
|
791
|
+
|
|
792
|
+
def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
|
|
793
|
+
return x1_dtype
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
class MaxPoolGradV1(Primitive):
|
|
797
|
+
"""Performs gradients of the MaxPoolV1 operation."""
|
|
798
|
+
|
|
799
|
+
@prim_attr_register
|
|
800
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
|
|
801
|
+
self.init_prim_io_names(
|
|
802
|
+
inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
|
|
803
|
+
|
|
804
|
+
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
|
805
|
+
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
806
|
+
self.pad_mode = validator.check_string(
|
|
807
|
+
pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
|
808
|
+
self.add_prim_attr("pad_mode", self.pad_mode)
|
|
809
|
+
self.format = validator.check_string(
|
|
810
|
+
data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
811
|
+
self.add_prim_attr('data_format', self.format)
|
|
812
|
+
|
|
813
|
+
def _grad_check_int_or_tuple(arg_name, arg_val):
|
|
814
|
+
validator.check_value_type(
|
|
815
|
+
arg_name, arg_val, (int, tuple), self.name)
|
|
816
|
+
error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number "
|
|
817
|
+
f"or a tuple of two or four positive int numbers, but got {arg_val}")
|
|
818
|
+
if isinstance(arg_val, int):
|
|
819
|
+
ret = (1, 1, arg_val, arg_val)
|
|
820
|
+
elif len(arg_val) == 2:
|
|
821
|
+
ret = (1, 1, arg_val[0], arg_val[1])
|
|
822
|
+
elif len(arg_val) == 4:
|
|
823
|
+
ret = arg_val
|
|
824
|
+
else:
|
|
825
|
+
raise error_msg
|
|
826
|
+
# whether all elements of tuple are positive integers
|
|
827
|
+
for item in ret:
|
|
828
|
+
if not isinstance(item, int) or item <= 0:
|
|
829
|
+
raise error_msg
|
|
830
|
+
return ret
|
|
831
|
+
|
|
832
|
+
self.kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
|
|
833
|
+
self.strides = _grad_check_int_or_tuple("strides", strides)
|
|
834
|
+
|
|
835
|
+
kernel_size_adapted = self.kernel_size if self.format == 'NCHW' else (
|
|
836
|
+
self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
|
|
837
|
+
strides_adapted = self.strides if self.format == 'NCHW' else (
|
|
838
|
+
self.strides[0], self.strides[2], self.strides[3], self.strides[1])
|
|
839
|
+
|
|
840
|
+
if len(kernel_size) == 4:
|
|
841
|
+
kernel_size_adapted = kernel_size
|
|
842
|
+
if len(strides) == 4:
|
|
843
|
+
strides_adapted = strides
|
|
844
|
+
|
|
845
|
+
self.add_prim_attr("kernel_size", kernel_size_adapted)
|
|
846
|
+
self.add_prim_attr("strides", strides_adapted)
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
class MaxPoolGradGrad(_PoolGrad):
|
|
850
|
+
r"""
|
|
851
|
+
Performs gradients of the MaxPoolGrad operation.
|
|
852
|
+
|
|
853
|
+
Args:
|
|
854
|
+
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
|
|
855
|
+
is an int number that represents height and width are both kernel_size, or a tuple
|
|
856
|
+
of two int numbers that represent height and width respectively. Default: 1.
|
|
857
|
+
strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
|
|
858
|
+
the height and width of movement are both strides, or a tuple of two int numbers that
|
|
859
|
+
represent height and width of movement respectively. Default: 1.
|
|
860
|
+
pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
|
|
861
|
+
Default: "valid".
|
|
862
|
+
|
|
863
|
+
- same: Adopts the way of completion. The height and width of the output will be the same as
|
|
864
|
+
the input. The total number of padding will be calculated in horizontal and vertical
|
|
865
|
+
directions and evenly distributed to top and bottom, left and right if possible.
|
|
866
|
+
Otherwise, the last extra padding will be done from the bottom and the right side.
|
|
867
|
+
|
|
868
|
+
- valid: Adopts the way of discarding. The possible largest height and width of output
|
|
869
|
+
will be returned without padding. Extra pixels will be discarded.
|
|
870
|
+
|
|
871
|
+
Inputs:
|
|
872
|
+
- **origin_input** (Tensor) - Tensor with data format "NCHW".
|
|
873
|
+
For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
|
|
874
|
+
- **origin_output** (Tensor) - Data type same as `origin_input`.
|
|
875
|
+
- **grad** (Tensor) - Data type and shape same as `origin_input`.
|
|
876
|
+
|
|
877
|
+
Outputs:
|
|
878
|
+
Tensor, with data type same as `origin_input`. Shape same as `origin_output`.
|
|
879
|
+
|
|
880
|
+
Raises:
|
|
881
|
+
TypeError: If kernel_size is neither int nor a tuple of 2/4 int numbers.
|
|
882
|
+
TypeError: If strides is neither int nor a tuple of 2/4 int numbers.
|
|
883
|
+
TypeError: If pad_mode is not string.
|
|
884
|
+
ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
|
|
885
|
+
TypeError: For Ascend, input data type is not float16. For CPU or GPU, input data type is neither
|
|
886
|
+
float16 nor float32.
|
|
887
|
+
ValueError: If the rank of `origin_input`, `origin_output` or `grad` is not equal to 4.
|
|
888
|
+
ValueError: If data types of all inputs are not equal.
|
|
889
|
+
ValueError: If the shapes of `origin_input` and `grad` are not equal.
|
|
890
|
+
|
|
891
|
+
Supported Platforms:
|
|
892
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
893
|
+
"""
|
|
894
|
+
|
|
895
|
+
@prim_attr_register
|
|
896
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
|
|
897
|
+
super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode)
|
|
898
|
+
|
|
899
|
+
def infer_shape(self, x1_shape, x2_shape, grad_shape):
|
|
900
|
+
return x2_shape
|
|
901
|
+
|
|
902
|
+
def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
|
|
903
|
+
args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
|
|
904
|
+
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
|
|
905
|
+
return x2_dtype
|
|
906
|
+
|
|
907
|
+
|
|
908
|
+
def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):
|
|
909
|
+
"""
|
|
910
|
+
Helper to compute the max pool3d grad pads according to pad_mode.
|
|
911
|
+
"""
|
|
912
|
+
|
|
913
|
+
def get_pad(origin_shape, ksize, stride):
|
|
914
|
+
tail = origin_shape % stride
|
|
915
|
+
pad = (ksize - tail) if tail > 0 else (ksize - stride)
|
|
916
|
+
pad = max(pad, 0)
|
|
917
|
+
pad1 = int(pad / 2)
|
|
918
|
+
pad2 = int(pad / 2) + pad % 2
|
|
919
|
+
return pad1, pad2
|
|
920
|
+
|
|
921
|
+
_, _, d, h, w = input_shape
|
|
922
|
+
_, _, kd, kh, kw = kernel_size
|
|
923
|
+
_, _, strd, strh, strw = strides
|
|
924
|
+
|
|
925
|
+
pads = (0, 0, 0, 0, 0, 0)
|
|
926
|
+
if pad_mode == 'SAME':
|
|
927
|
+
pads_d = get_pad(d, kd, strd)
|
|
928
|
+
pads_h = get_pad(h, kh, strh)
|
|
929
|
+
pads_w = get_pad(w, kw, strw)
|
|
930
|
+
pads = pads_d + pads_h + pads_w
|
|
931
|
+
return pads
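# Illustrative worked example (not part of the MindSpore source; the helper name is
# hypothetical): for 'SAME' pooling with size d=7, kernel kd=3 and stride strd=2,
# get_pad computes tail = 7 % 2 = 1, pad = 3 - 1 = 2, so (pad_head, pad_tail) = (1, 1).
def _same_pad_1d(origin, ksize, stride):
    # Same arithmetic as get_pad above, restated for a single dimension.
    tail = origin % stride
    pad = (ksize - tail) if tail > 0 else (ksize - stride)
    pad = max(pad, 0)
    return pad // 2, pad // 2 + pad % 2

assert _same_pad_1d(7, 3, 2) == (1, 1)
assert _same_pad_1d(8, 3, 2) == (0, 1)  # when an even split is impossible, the extra pad goes to the tail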
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
class MaxPool3DGrad(Primitive):
|
|
935
|
+
"""Gradients of the max pool3d operation."""
|
|
936
|
+
|
|
937
|
+
@prim_attr_register
|
|
938
|
+
def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1),
|
|
939
|
+
pad_mode='VALID', pad_list=0, data_format="NCDHW"):
|
|
940
|
+
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
|
|
941
|
+
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
|
942
|
+
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
943
|
+
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
|
944
|
+
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
|
|
945
|
+
if pad_mode.upper() == 'PAD':
|
|
946
|
+
pad_mode = 'CALCULATED'
|
|
947
|
+
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'CALCULATED'], 'pad_mode', self.name)
|
|
948
|
+
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
|
|
949
|
+
allow_five=True, ret_five=True)
|
|
950
|
+
self.add_prim_attr("kernel_size", self.kernel_size)
|
|
951
|
+
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
|
|
952
|
+
self.add_prim_attr("strides", self.strides)
|
|
953
|
+
validator.check_value_type('pad_list', pad_list, (int, tuple), self.name)
|
|
954
|
+
self.pad_list = pad_list
|
|
955
|
+
if isinstance(self.pad_list, int):
|
|
956
|
+
self.pad_list = (self.pad_list,) * 6
|
|
957
|
+
if len(self.pad_list) == 3:
|
|
958
|
+
self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[2])
|
|
959
|
+
if len(self.pad_list) != 3 and len(self.pad_list) != 6:
|
|
960
|
+
raise ValueError(f"For `maxpool3d` attr 'pad_list' must be an positive int number or a tuple of "
|
|
961
|
+
f"three or six positive int numbers, but got `{len(self.pad_list)}` numbers.")
|
|
962
|
+
if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0):
|
|
963
|
+
raise ValueError(f"For '{self.name}', when pad_list is not 0, pad_mode must be set as 'pad'.")
|
|
964
|
+
if self.pad_mode == 'CALCULATED':
|
|
965
|
+
for item in self.pad_list:
|
|
966
|
+
validator.check_non_negative_int(item, 'pad_list item', self.name)
|
|
967
|
+
self.add_prim_attr("pad_list", self.pad_list)
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
class MaxPool3DGradGrad(PrimitiveWithInfer):
|
|
971
|
+
r"""Gradients of the max pool3d grad operation.
|
|
972
|
+
|
|
973
|
+
Args:
|
|
974
|
+
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
|
|
975
|
+
is an int number that represents depth, height and width are both kernel_size, or a tuple
|
|
976
|
+
of three int numbers that represent depth, height and width respectively. Default: 1.
|
|
977
|
+
strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
|
|
978
|
+
the depth, height and width of movement are both strides, or a tuple of three int numbers that
|
|
979
|
+
represent depth, height and width of movement respectively. Default: 1.
|
|
980
|
+
pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
|
|
981
|
+
Default: "valid".
|
|
982
|
+
|
|
983
|
+
- same: Adopts the way of completion. The depth, height and width of the output will be the
|
|
984
|
+
same as the input. The total number of padding will be calculated in depth, horizontal and
|
|
985
|
+
vertical directions and evenly distributed to front and back, top and bottom, left and
|
|
986
|
+
right if possible. Otherwise, the last extra padding will be done from the back, the bottom
|
|
987
|
+
and the right side.
|
|
988
|
+
|
|
989
|
+
- valid: Adopts the way of discarding. The possible largest height and width of output
|
|
990
|
+
will be returned without padding. Extra pixels will be discarded.
|
|
991
|
+
|
|
992
|
+
Inputs:
|
|
993
|
+
- **origin_input** (Tensor) - Tensor with data format "NCDHW".
|
|
994
|
+
For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
|
|
995
|
+
- **origin_output** (Tensor) - Data type same as `origin_input`.
|
|
996
|
+
- **grad** (Tensor) - Data type and shape same as `origin_input`.
|
|
997
|
+
|
|
998
|
+
Outputs:
|
|
999
|
+
Tensor, with data type same as `origin_input`. Shape same as `origin_output`.
|
|
1000
|
+
|
|
1001
|
+
Raises:
|
|
1002
|
+
TypeError: If kernel_size is neither int nor a tuple of 3/5 int numbers.
|
|
1003
|
+
TypeError: If strides is neither int nor a tuple of 3/5 int numbers.
|
|
1004
|
+
TypeError: If pad_mode is not string.
|
|
1005
|
+
ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
|
|
1006
|
+
TypeError: For Ascend, input data type is not float16. For CPU or GPU, input data type is neither
|
|
1007
|
+
float16 nor float32.
|
|
1008
|
+
ValueError: If the rank of `origin_input`, `origin_output` or `grad` is not equal to 5.
|
|
1009
|
+
ValueError: If data types of all inputs are not equal.
|
|
1010
|
+
ValueError: If the shapes of `origin_input` and `grad` are not equal.
|
|
1011
|
+
|
|
1012
|
+
Supported Platforms:
|
|
1013
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1014
|
+
"""
|
|
1015
|
+
|
|
1016
|
+
@prim_attr_register
|
|
1017
|
+
def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
|
|
1018
|
+
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
|
1019
|
+
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
1020
|
+
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
|
1021
|
+
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
|
|
1022
|
+
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
|
1023
|
+
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
|
|
1024
|
+
allow_five=True, ret_five=True)
|
|
1025
|
+
self.add_prim_attr("kernel_size", self.kernel_size)
|
|
1026
|
+
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
|
|
1027
|
+
self.add_prim_attr("strides", self.strides)
|
|
1028
|
+
|
|
1029
|
+
def infer_shape(self, x_shape, y_shape, grad_shape):
|
|
1030
|
+
validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
|
|
1031
|
+
validator.check('x_shape', x_shape, 'grad_shape', grad_shape, prim_name=self.name)
|
|
1032
|
+
pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
|
|
1033
|
+
for pad in pad_list:
|
|
1034
|
+
validator.check_non_negative_int(pad, 'element of pad_list', self.name)
|
|
1035
|
+
self.add_prim_attr("pad_list", pad_list)
|
|
1036
|
+
return y_shape
|
|
1037
|
+
|
|
1038
|
+
def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
|
|
1039
|
+
args = {'x_dtype': x_dtype, 'y_dtype': y_dtype}
|
|
1040
|
+
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
|
1041
|
+
validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1042
|
+
return x_dtype
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
class MaxPoolGradWithArgmax(Primitive):
|
|
1046
|
+
"""Computes the gradients of MaxPoolWithArgmax."""
|
|
1047
|
+
@prim_attr_register
|
|
1048
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
|
|
1049
|
+
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
|
|
1050
|
+
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
|
1051
|
+
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
|
1052
|
+
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
|
1053
|
+
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
|
1054
|
+
self.add_prim_attr("pad_mode", self.pad_mode)
|
|
1055
|
+
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
1056
|
+
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
1057
|
+
raise ValueError("NHWC format only support in GPU target.")
|
|
1058
|
+
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
|
|
1059
|
+
if not self.is_maxpoolgradwithargmax:
|
|
1060
|
+
self.add_prim_attr('data_format', self.format)
|
|
1061
|
+
|
|
1062
|
+
def _grad_check_int_or_tuple(arg_name, arg_val):
|
|
1063
|
+
validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
|
|
1064
|
+
error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be an positive int number "
|
|
1065
|
+
f"or a tuple of two or four positive int numbers, but got {arg_val}")
|
|
1066
|
+
if isinstance(arg_val, int):
|
|
1067
|
+
ret = (1, arg_val, arg_val, 1)
|
|
1068
|
+
elif len(arg_val) == 2:
|
|
1069
|
+
ret = (1, arg_val[0], arg_val[1], 1)
|
|
1070
|
+
elif len(arg_val) == 4:
|
|
1071
|
+
ret = arg_val
|
|
1072
|
+
else:
|
|
1073
|
+
raise error_msg
|
|
1074
|
+
# whether all elements of tuple are positive integers
|
|
1075
|
+
for item in ret:
|
|
1076
|
+
if not isinstance(item, int) or item <= 0:
|
|
1077
|
+
raise error_msg
|
|
1078
|
+
return ret
|
|
1079
|
+
|
|
1080
|
+
kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
|
|
1081
|
+
self.kernel_size = kernel_size
|
|
1082
|
+
self.add_prim_attr("kernel_size", self.kernel_size)
|
|
1083
|
+
|
|
1084
|
+
strides = _grad_check_int_or_tuple("strides", strides)
|
|
1085
|
+
self.strides = strides
|
|
1086
|
+
self.add_prim_attr("strides", self.strides)
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
class MaxPoolGradWithArgmaxV2(Primitive):
|
|
1090
|
+
"""Gradients of the MaxPoolWithArgmaxV2 operation."""
|
|
1091
|
+
|
|
1092
|
+
@prim_attr_register
|
|
1093
|
+
def __init__(self, kernel_size, strides=None, pads=0, dilation=(1, 1), ceil_mode=False, argmax_type=mstype.int64):
|
|
1094
|
+
self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['y'])
|
|
1095
|
+
self.kernel_size = _check_positive_int_or_tuple("kernel_size", kernel_size, self.name, allow_four=True,
|
|
1096
|
+
ret_four=True)
|
|
1097
|
+
self.add_prim_attr('kernel_size', self.kernel_size)
|
|
1098
|
+
if strides is None:
|
|
1099
|
+
strides = kernel_size
|
|
1100
|
+
self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=True, ret_four=True)
|
|
1101
|
+
self.add_prim_attr('strides', self.strides)
|
|
1102
|
+
self.pads = _check_positive_int_or_tuple("pads", pads, self.name, allow_four=True, ret_four=True,
|
|
1103
|
+
strict_positive=False)
|
|
1104
|
+
self.add_prim_attr('pads', self.pads)
|
|
1105
|
+
validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
|
|
1106
|
+
self.add_prim_attr('ceil_mode', self.ceil_mode)
|
|
1107
|
+
self.dilation = _check_positive_int_or_tuple("dilation", dilation, self.name, allow_four=True, ret_four=True)
|
|
1108
|
+
self.add_prim_attr('dilation', self.dilation)
|
|
1109
|
+
self.add_prim_attr('argmax_type', self.argmax_type)
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
class MaxPool3DGradWithArgmax(Primitive):
|
|
1113
|
+
"""Gradients of the maxpool3Dwithargmax operation."""
|
|
1114
|
+
|
|
1115
|
+
@prim_attr_register
|
|
1116
|
+
def __init__(self, ksize, strides, pads, dilation=(1, 1, 1), ceil_mode=False, data_format="NCDHW"):
|
|
1117
|
+
self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
|
|
1118
|
+
validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
|
|
1119
|
+
validator.check_value_type('data_format', data_format, str, self.name)
|
|
1120
|
+
self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
|
|
1121
|
+
self.ksize = _check_3d_int_or_tuple("ksize", ksize, self.name, ret_five=False)
|
|
1122
|
+
self.add_prim_attr('ksize', self.ksize)
|
|
1123
|
+
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, ret_five=False)
|
|
1124
|
+
self.add_prim_attr('strides', self.strides)
|
|
1125
|
+
self.pads = _check_3d_int_or_tuple("pads", pads, self.name, greater_zero=False, ret_five=False)
|
|
1126
|
+
self.add_prim_attr('pads', self.pads)
|
|
1127
|
+
self.dilation = _check_3d_int_or_tuple("dilation", dilation, self.name, allow_five=True, ret_five=False)
|
|
1128
|
+
self.add_prim_attr('dilation', self.dilation)
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
class MaxPoolGradGradWithArgmax(_PoolGrad):
|
|
1132
|
+
r"""
|
|
1133
|
+
Computes the gradients of MaxPoolGradWithArgmax.
|
|
1134
|
+
|
|
1135
|
+
Args:
|
|
1136
|
+
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
|
|
1137
|
+
is an int number that represents height and width are both kernel_size, or a tuple
|
|
1138
|
+
of two int numbers that represent height and width respectively. Default: 1.
|
|
1139
|
+
strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
|
|
1140
|
+
the height and width of movement are both strides, or a tuple of two int numbers that
|
|
1141
|
+
represent height and width of movement respectively. Default: 1.
|
|
1142
|
+
pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
|
|
1143
|
+
Default: "valid".
|
|
1144
|
+
|
|
1145
|
+
- same: Adopts the way of completion. The height and width of the output will be the same as
|
|
1146
|
+
the input. The total number of padding will be calculated in horizontal and vertical
|
|
1147
|
+
directions and evenly distributed to top and bottom, left and right if possible.
|
|
1148
|
+
Otherwise, the last extra padding will be done from the bottom and the right side.
|
|
1149
|
+
|
|
1150
|
+
- valid: Adopts the way of discarding. The possible largest height and width of output
|
|
1151
|
+
will be returned without padding. Extra pixels will be discarded.
|
|
1152
|
+
|
|
1153
|
+
Inputs:
|
|
1154
|
+
- **x** (Tensor) - Tensor with data format "NCHW".
|
|
1155
|
+
For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
|
|
1156
|
+
- **grad** (Tensor) - Data type and shape same as `x`.
|
|
1157
|
+
- **argmax** (Tensor) - Data type must be int32 or int64.
|
|
1158
|
+
|
|
1159
|
+
Outputs:
|
|
1160
|
+
Tensor, with data type same as `x`. Shape same as `argmax`.
|
|
1161
|
+
|
|
1162
|
+
Raises:
|
|
1163
|
+
TypeError: If kernel_size is neither int nor a tuple of 2/4 int numbers.
|
|
1164
|
+
TypeError: If strides is neither int nor a tuple of 2/4 int numbers.
|
|
1165
|
+
TypeError: If pad_mode is not string.
|
|
1166
|
+
ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
|
|
1167
|
+
TypeError: For Ascend, the data types of `x` and `grad` are not float16.
|
|
1168
|
+
For CPU or GPU, the data types of `x` and `grad` are neither float16 nor float32.
|
|
1169
|
+
TypeError: The data type of `argmax` is neither int32 nor int64.
|
|
1170
|
+
ValueError: If the rank of `x`, `grad` or `argmax` is not equal to 4.
|
|
1171
|
+
ValueError: If the shapes of `x` and `grad` are not equal.
|
|
1172
|
+
|
|
1173
|
+
Supported Platforms:
|
|
1174
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1175
|
+
"""
|
|
1176
|
+
|
|
1177
|
+
@prim_attr_register
|
|
1178
|
+
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
|
|
1179
|
+
self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
|
|
1180
|
+
super(MaxPoolGradGradWithArgmax, self).__init__(kernel_size, strides, pad_mode)
|
|
1181
|
+
|
|
1182
|
+
def infer_shape(self, x_shape, grad_shape, argmax_shape):
|
|
1183
|
+
if not grad_shape:
|
|
1184
|
+
raise TypeError("The dout of MaxPoolGradGradWithArgmax must be a Tensor.")
|
|
1185
|
+
return x_shape
|
|
1186
|
+
|
|
1187
|
+
def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
|
|
1188
|
+
args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
|
|
1189
|
+
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
|
|
1190
|
+
return grad_dtype
|
|
1191
|
+
|
|
1192
|
+
|
|
1193
|
+
class MinimumGradGrad(Primitive):
|
|
1194
|
+
"""Grad for minimum_grad."""
|
|
1195
|
+
@prim_attr_register
|
|
1196
|
+
def __init__(self):
|
|
1197
|
+
"""Initialize MinimumGradGrad"""
|
|
1198
|
+
super().__init__("MinimumGradGrad")
|
|
1199
|
+
self.init_prim_io_names(inputs=['x1', 'x2', 'grad_y1', 'grad_y2'],
|
|
1200
|
+
outputs=['sopd_x1', 'sopd_x2', 'sopd_grads'])
|
|
1201
|
+
|
|
1202
|
+
|
|
1203
|
+
class L2NormalizeGrad(Primitive):
|
|
1204
|
+
r"""
|
|
1205
|
+
Gradients of L2 normalize.
|
|
1206
|
+
|
|
1207
|
+
Args:
|
|
1208
|
+
axis (Union[list(int), tuple(int), int]): The begin axis for the input to apply L2 normalize. Default: 0.
|
|
1209
|
+
epsilon (float): A small value added for numerical stability. Default: 1e-4.
|
|
1210
|
+
|
|
1211
|
+
Inputs:
|
|
1212
|
+
- **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize.
|
|
1213
|
+
- **out** (Tensor) - Must be the output of forward operator L2Normalize.
|
|
1214
|
+
- **dout** (Tensor) - The backprop of the next layer.
|
|
1215
|
+
|
|
1216
|
+
Outputs:
|
|
1217
|
+
Tensor, gradients of L2Normalize `input_x`.
|
|
1218
|
+
"""
|
|
1219
|
+
|
|
1220
|
+
@prim_attr_register
|
|
1221
|
+
def __init__(self, axis=0, epsilon=1e-4):
|
|
1222
|
+
axis = [axis] if isinstance(axis, int) else axis
|
|
1223
|
+
validator.check_value_type('axis', axis, [list, tuple], self.name)
|
|
1224
|
+
validator.check_value_type('epsilon', epsilon, [int, float], self.name)
|
|
1225
|
+
self.add_prim_attr('axis', axis)
|
|
1226
|
+
self.init_attrs['axis'] = axis
|
|
1227
|
+
if len(axis) != 1:
|
|
1228
|
+
raise TypeError("The length of axis must be 1, later will support multiple axis!")
|
|
1229
|
+
|
|
1230
|
+
|
|
1231
|
+
class LSTMGradData(Primitive):
|
|
1232
|
+
"""Computes the data gradients of LSTM."""
|
|
1233
|
+
|
|
1234
|
+
@prim_attr_register
|
|
1235
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
|
1236
|
+
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1237
|
+
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1238
|
+
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1239
|
+
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1240
|
+
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
1241
|
+
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
|
1242
|
+
self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
|
|
1243
|
+
|
|
1244
|
+
if bidirectional:
|
|
1245
|
+
self.num_directions = 2
|
|
1246
|
+
else:
|
|
1247
|
+
self.num_directions = 1
|
|
1248
|
+
|
|
1249
|
+
|
|
1250
|
+
class LSTMGradWeight(Primitive):
|
|
1251
|
+
"""Computes the weight gradients of LSTM."""
|
|
1252
|
+
|
|
1253
|
+
@prim_attr_register
|
|
1254
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
|
1255
|
+
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1256
|
+
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1257
|
+
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1258
|
+
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1259
|
+
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
1260
|
+
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
|
1261
|
+
self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
|
|
1262
|
+
|
|
1263
|
+
if bidirectional:
|
|
1264
|
+
self.num_directions = 2
|
|
1265
|
+
else:
|
|
1266
|
+
self.num_directions = 1
|
|
1267
|
+
|
|
1268
|
+
|
|
1269
|
+
class LSTMGrad(Primitive):
|
|
1270
|
+
"""Computes the data and weight gradients of LSTM."""
|
|
1271
|
+
|
|
1272
|
+
@prim_attr_register
|
|
1273
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
|
|
1274
|
+
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1275
|
+
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1276
|
+
self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
|
|
1277
|
+
'proj_size', self.name)
|
|
1278
|
+
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1279
|
+
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1280
|
+
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
1281
|
+
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
|
1282
|
+
self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
|
|
1283
|
+
|
|
1284
|
+
if bidirectional:
|
|
1285
|
+
self.num_directions = 2
|
|
1286
|
+
else:
|
|
1287
|
+
self.num_directions = 1
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
class DynamicRNNGrad(Primitive):
|
|
1291
|
+
"""Computes the input gradients of DynamicRNN."""
|
|
1292
|
+
|
|
1293
|
+
@prim_attr_register
|
|
1294
|
+
def __init__(self,
|
|
1295
|
+
cell_type='LSTM',
|
|
1296
|
+
direction='UNIDIRECTIONAL',
|
|
1297
|
+
cell_depth=1,
|
|
1298
|
+
use_peephole=False,
|
|
1299
|
+
keep_prob=1.0,
|
|
1300
|
+
cell_clip=-1.0,
|
|
1301
|
+
num_proj=0,
|
|
1302
|
+
time_major=True,
|
|
1303
|
+
forget_bias=0.0):
|
|
1304
|
+
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
class GruGradData(PrimitiveWithInfer):
|
|
1308
|
+
"""Computes the data gradients of GRU."""
|
|
1309
|
+
|
|
1310
|
+
@prim_attr_register
|
|
1311
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
|
1312
|
+
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1313
|
+
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1314
|
+
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1315
|
+
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1316
|
+
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
1317
|
+
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
|
1318
|
+
self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
|
|
1319
|
+
|
|
1320
|
+
if bidirectional:
|
|
1321
|
+
self.num_directions = 2
|
|
1322
|
+
else:
|
|
1323
|
+
self.num_directions = 1
|
|
1324
|
+
|
|
1325
|
+
def infer_shape(self, y_shape, dy_shape, dhy_shape, w_shape,
|
|
1326
|
+
hx_shape, reserve_shape, state_shape):
|
|
1327
|
+
# dhy and dcy should be same shape
|
|
1328
|
+
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
|
|
1329
|
+
|
|
1330
|
+
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, validator.EQ, "h_shape[0]", self.name)
|
|
1331
|
+
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
|
|
1332
|
+
|
|
1333
|
+
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
|
|
1334
|
+
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
|
|
1335
|
+
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, validator.EQ, "dy[2]", self.name)
|
|
1336
|
+
|
|
1337
|
+
dx_shape = (y_shape[0], y_shape[1], self.input_size)
|
|
1338
|
+
dhx_shape = dhy_shape
|
|
1339
|
+
|
|
1340
|
+
return (dx_shape, dhx_shape)
|
|
1341
|
+
|
|
1342
|
+
def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, w_dtype,
|
|
1343
|
+
hx_dtype, reserve_dtype, state_dtype):
|
|
1344
|
+
args = {"dy": dy_dtype, "dhy": dhy_dtype}
|
|
1345
|
+
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
|
|
1346
|
+
return (dy_dtype, dy_dtype)
|
|
1347
|
+
|
|
1348
|
+
|
|
1349
|
+
class GruGradWeight(PrimitiveWithInfer):
|
|
1350
|
+
"""Computes the weight gradients of GRU."""
|
|
1351
|
+
|
|
1352
|
+
@prim_attr_register
|
|
1353
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
|
1354
|
+
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1355
|
+
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1356
|
+
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1357
|
+
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1358
|
+
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
1359
|
+
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
|
1360
|
+
self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
|
|
1361
|
+
|
|
1362
|
+
if bidirectional:
|
|
1363
|
+
self.num_directions = 2
|
|
1364
|
+
else:
|
|
1365
|
+
self.num_directions = 1
|
|
1366
|
+
|
|
1367
|
+
def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
|
|
1368
|
+
weight_size = 0
|
|
1369
|
+
gate_size = 3 * self.hidden_size
|
|
1370
|
+
for layer in range(self.num_layers):
|
|
1371
|
+
for _ in range(self.num_directions):
|
|
1372
|
+
input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
|
|
1373
|
+
weight_size += gate_size * input_layer_size
|
|
1374
|
+
weight_size += gate_size * self.hidden_size
|
|
1375
|
+
if self.has_bias:
|
|
1376
|
+
weight_size += 2 * gate_size
|
|
1377
|
+
|
|
1378
|
+
return (weight_size, 1, 1)
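# Illustrative sketch (not part of the MindSpore source; the function name is
# hypothetical): the loop above sizes the flat GRU weight as, per layer and
# direction, gate_size * input_layer_size + gate_size * hidden_size
# (+ 2 * gate_size for the biases), with gate_size = 3 * hidden_size.
def gru_flat_weight_size(input_size, hidden_size, num_layers, has_bias, num_directions):
    gate_size = 3 * hidden_size
    total = 0
    for layer in range(num_layers):
        for _ in range(num_directions):
            in_size = input_size if layer == 0 else hidden_size * num_directions
            total += gate_size * in_size + gate_size * hidden_size
            if has_bias:
                total += 2 * gate_size
    return total

# gru_flat_weight_size(10, 20, 1, True, 1) == 60*10 + 60*20 + 2*60 == 1920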
|
|
1379
|
+
|
|
1380
|
+
def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
|
|
1381
|
+
return hx_dtype
|
|
1382
|
+
|
|
1383
|
+
|
|
1384
|
+
class GRUV2Grad(Primitive):
|
|
1385
|
+
"""Computes the grad gradients of GRU."""
|
|
1386
|
+
|
|
1387
|
+
@prim_attr_register
|
|
1388
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
|
1389
|
+
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1390
|
+
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1391
|
+
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1392
|
+
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1393
|
+
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
1394
|
+
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
|
1395
|
+
self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
|
|
1396
|
+
|
|
1397
|
+
if bidirectional:
|
|
1398
|
+
self.num_directions = 2
|
|
1399
|
+
else:
|
|
1400
|
+
self.num_directions = 1
|
|
1401
|
+
|
|
1402
|
+
|
|
1403
|
+
class DynamicGRUV2Grad(Primitive):
|
|
1404
|
+
r"""
|
|
1405
|
+
Computes the input gradients of DynamicGRUV2.
|
|
1406
|
+
|
|
1407
|
+
Args:
|
|
1408
|
+
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
|
1409
|
+
Only 'UNIDIRECTIONAL' is currently supported.
|
|
1410
|
+
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
|
|
1411
|
+
keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
|
|
1412
|
+
cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
|
|
1413
|
+
num_proj (int): An integer identifying the num proj in the op. Default: 0.
|
|
1414
|
+
time_major (bool): A bool identifying the time major in the op. Default: ``True``.
|
|
1415
|
+
gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
|
|
1416
|
+
'zrh' is another option.
|
|
1417
|
+
reset_after (bool): A bool identifying whether to apply the reset gate after matrix multiplication.
|
|
1418
|
+
Default: ``True``.
|
|
1419
|
+
|
|
1420
|
+
Inputs:
|
|
1421
|
+
- **x** (Tensor) - Current words. Tensor of shape :math:`(num\_step, batch\_size, input\_size)`.
|
|
1422
|
+
The data type must be float16 or float32.
|
|
1423
|
+
- **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input\_size, 3 x hidden\_size)`.
|
|
1424
|
+
The data type must be float16 or float32.
|
|
1425
|
+
- **weight_hidden** (Tensor) - Weight. Tensor of shape :math:`(hidden\_size, 3 x hidden\_size)`.
|
|
1426
|
+
The data type must be float16 or float32.
|
|
1427
|
+
- **y** (Tensor) - A Tensor of shape :math:
|
|
1428
|
+
if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj))`,
|
|
1429
|
+
if num_proj == 0 `(num_step, batch_size, hidden_size)`.
|
|
1430
|
+
The data type must be float16 or float32.
|
|
1431
|
+
- **init_h** (Tensor) - Hidden state of initial time.
|
|
1432
|
+
Tensor of shape :math:`(batch\_size, hidden\_size)`.
|
|
1433
|
+
The data type must be float16 or float32.
|
|
1434
|
+
- **h** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
|
|
1435
|
+
The data type must be float16 or float32.
|
|
1436
|
+
- **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
|
|
1437
|
+
- **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
|
|
1438
|
+
- **update** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
|
|
1439
|
+
The data type must be float16 or float32.
|
|
1440
|
+
- **reset** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
|
|
1441
|
+
The data type must be float16 or float32.
|
|
1442
|
+
- **new** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
|
|
1443
|
+
The data type must be float16 or float32.
|
|
1444
|
+
- **hidden_new** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
|
|
1445
|
+
The data type must be float16 or float32.
|
|
1446
|
+
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch\_size)`.
|
|
1447
|
+
Only `None` is currently supported.
|
|
1448
|
+
- **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32.
|
|
1449
|
+
|
|
1450
|
+
Outputs:
|
|
1451
|
+
- **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`.
|
|
1452
|
+
Has the same type as input `x`.
|
|
1453
|
+
- **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`.
|
|
1454
|
+
Has the same type as input `x`.
|
|
1455
|
+
- **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
|
|
1456
|
+
Has the same type as input `init\_h`.
|
|
1457
|
+
- **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
|
|
1458
|
+
Has the same type as input `init\_h`.
|
|
1459
|
+
- **dx** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
|
|
1460
|
+
Has the same type as input `x`.
|
|
1461
|
+
- **dh_prev** (Tensor) - A Tensor of shape :math:`(batch\_size, hidden\_size)`.
|
|
1462
|
+
Has the same type as input `init\_h`.
|
|
1463
|
+
"""
|
|
1464
|
+
|
|
1465
|
+
@prim_attr_register
|
|
1466
|
+
def __init__(self,
|
|
1467
|
+
direction='UNIDIRECTIONAL',
|
|
1468
|
+
cell_depth=1,
|
|
1469
|
+
keep_prob=1.0,
|
|
1470
|
+
cell_clip=-1.0,
|
|
1471
|
+
num_proj=0,
|
|
1472
|
+
time_major=True,
|
|
1473
|
+
gate_order="rzh",
|
|
1474
|
+
reset_after=True):
|
|
1475
|
+
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
|
1476
|
+
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
|
1477
|
+
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
|
|
1478
|
+
self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
|
|
1479
|
+
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
|
1480
|
+
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
|
|
1481
|
+
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
|
|
1482
|
+
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
|
|
1483
|
+
self.init_prim_io_names(inputs=[
|
|
1484
|
+
"x", "weight_input", "weight_hidden", "y", "init_h", "h", "dy",
|
|
1485
|
+
"dh", "update", "reset", "new", "hidden_new", "seq_length", "mask"
|
|
1486
|
+
],
|
|
1487
|
+
outputs=[
|
|
1488
|
+
"dw_input", "dw_hidden", "db_input",
|
|
1489
|
+
"db_hidden", "dx", "dh_prev"
|
|
1490
|
+
])
|
|
1491
|
+
|
|
1492
|
+
|
|
1493
|
+
class RandomGammaGrad(Primitive):
|
|
1494
|
+
r"""
|
|
1495
|
+
Computes the derivative of a random sample of Gamma with respect to alpha.
|
|
1496
|
+
|
|
1497
|
+
Inputs:
|
|
1498
|
+
- **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
|
|
1499
|
+
It must be greater than 0. Must be one of the following types: float32, float64.
|
|
1500
|
+
- **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
|
|
1501
|
+
following types: float32, float64.
|
|
1502
|
+
|
|
1503
|
+
Outputs:
|
|
1504
|
+
The dtype is the same type as alpha.
|
|
1505
|
+
The output shape is derived from the input through broadcasting.
|
|
1506
|
+
|
|
1507
|
+
Raises:
|
|
1508
|
+
TypeError: If the data type of `alpha` and `sample` is neither float32 nor float64.
|
|
1509
|
+
TypeError: If the data types of `alpha` and `sample` are not the same.
|
|
1510
|
+
ValueError: If the last dimensions of the shapes of `sample` and `alpha` are not equal.
|
|
1511
|
+
|
|
1512
|
+
Supported Platforms:
|
|
1513
|
+
``GPU``
|
|
1514
|
+
|
|
1515
|
+
Examples:
|
|
1516
|
+
>>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
|
|
1517
|
+
>>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
|
|
1518
|
+
>>> randomgammagrad = ops.RandomGammaGrad()
|
|
1519
|
+
>>> output = randomgammagrad(alpha, sample)
|
|
1520
|
+
>>> print(output)
|
|
1521
|
+
[2.5142431 3.4334087 1.8847835 0.07780622]
|
|
1522
|
+
"""
|
|
1523
|
+
|
|
1524
|
+
@prim_attr_register
|
|
1525
|
+
def __init__(self):
|
|
1526
|
+
"""Initialize RandomGammaGrad"""
|
|
1527
|
+
self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
|
|
1528
|
+
self.add_prim_attr("side_effect_hidden", True)
|
|
1529
|
+
|
|
1530
|
+
|
|
1531
|
+
class ROIAlignGrad(Primitive):
|
|
1532
|
+
"""
|
|
1533
|
+
ROIAlignGrad operator.
|
|
1534
|
+
|
|
1535
|
+
Args:
|
|
1536
|
+
pooled_height (int): The output feature height.
|
|
1537
|
+
pooled_width (int): The output feature width.
|
|
1538
|
+
spatial_scale (float): The feature stride.
|
|
1539
|
+
sample_num (int): Number of sampling points. Default: 2.
|
|
1540
|
+
"""
|
|
1541
|
+
|
|
1542
|
+
@prim_attr_register
|
|
1543
|
+
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2):
|
|
1544
|
+
"""Initialize ROIAlignGrad"""
|
|
1545
|
+
self.init_prim_io_names(inputs=["dy", "rois", "xdiff_shape"], outputs=["dx"])
|
|
1546
|
+
validator.check_value_type("pooled_height", pooled_height, [int], self.name)
|
|
1547
|
+
validator.check_value_type("pooled_width", pooled_width, [int], self.name)
|
|
1548
|
+
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
|
|
1549
|
+
validator.check_value_type("sample_num", sample_num, [int], self.name)
|
|
1550
|
+
self.pooled_height = pooled_height
|
|
1551
|
+
self.pooled_width = pooled_width
|
|
1552
|
+
self.spatial_scale = spatial_scale
|
|
1553
|
+
self.sample_num = sample_num
|
|
1554
|
+
|
|
1555
|
+
|
|
1556
|
+
class PsROIPoolingGrad(PrimitiveWithInfer):
|
|
1557
|
+
"""
|
|
1558
|
+
PsROIPoolingGrad operator.
|
|
1559
|
+
"""
|
|
1560
|
+
|
|
1561
|
+
@prim_attr_register
|
|
1562
|
+
def __init__(self, batch_size, channels, height, width, num_rois,
|
|
1563
|
+
pooled_height, pooled_width, spatial_scale, out_dim):
|
|
1564
|
+
"""Initialize PsROIPoolingGrad"""
|
|
1565
|
+
validator.check_value_type("batch_size", batch_size, [int], self.name)
|
|
1566
|
+
validator.check_value_type("channels", channels, [int], self.name)
|
|
1567
|
+
validator.check_value_type("height", height, [int], self.name)
|
|
1568
|
+
validator.check_value_type("width", width, [int], self.name)
|
|
1569
|
+
validator.check_value_type("num_rois", num_rois, [int], self.name)
|
|
1570
|
+
validator.check_value_type("pooled_height", pooled_height, [int], self.name)
|
|
1571
|
+
validator.check_value_type("pooled_width", pooled_width, [int], self.name)
|
|
1572
|
+
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
|
|
1573
|
+
validator.check_value_type("out_dim", out_dim, [int], self.name)
|
|
1574
|
+
self.batch_size = batch_size
|
|
1575
|
+
self.channels = channels
|
|
1576
|
+
self.height = height
|
|
1577
|
+
self.width = width
|
|
1578
|
+
self.num_rois = num_rois
|
|
1579
|
+
self.pooled_height = pooled_height
|
|
1580
|
+
self.pooled_width = pooled_width
|
|
1581
|
+
self.spatial_scale = spatial_scale
|
|
1582
|
+
self.out_dim = out_dim
|
|
1583
|
+
|
|
1584
|
+
def infer_shape(self, ydiff_shape, rois_shape, mapping_channel_shape):
|
|
1585
|
+
return [self.batch_size, self.channels, self.height, self.width]
|
|
1586
|
+
|
|
1587
|
+
def infer_dtype(self, ydiff_type, rois_type, mapping_channel_type):
|
|
1588
|
+
return ydiff_type
|
|
1589
|
+
|
|
1590
|
+
|
|
1591
|
+
class _ActivationGrad(PrimitiveWithInfer):
|
|
1592
|
+
"""_ActivationGrad base class."""
|
|
1593
|
+
|
|
1594
|
+
@prim_attr_register
|
|
1595
|
+
def __init__(self):
|
|
1596
|
+
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
|
|
1597
|
+
|
|
1598
|
+
def infer_shape(self, y_grad_shape, x_shape):
|
|
1599
|
+
return x_shape
|
|
1600
|
+
|
|
1601
|
+
def infer_dtype(self, y_grad_dtype, x_dtype):
|
|
1602
|
+
valid_dtypes = (mstype.float16, mstype.float32)
|
|
1603
|
+
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
|
|
1604
|
+
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
|
|
1605
|
+
return x_dtype
|
|
1606
|
+
|
|
1607
|
+
|
|
1608
|
+
class SigmoidCrossEntropyWithLogitsGrad(Primitive):
|
|
1609
|
+
"""Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
|
|
1610
|
+
|
|
1611
|
+
@prim_attr_register
|
|
1612
|
+
def __init__(self):
|
|
1613
|
+
"""Initialize SigmoidCrossEntropyWithLogitsGrad"""
|
|
1614
|
+
self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
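# Illustrative numpy sketch (not part of the MindSpore source): the gradient of
# sigmoid cross-entropy with logits w.r.t. the logits x is the standard
# (sigmoid(x) - y), scaled by the incoming gradient dout.
import numpy as np

def sigmoid_cross_entropy_with_logits_grad(x, y, dout):
    sig = 1.0 / (1.0 + np.exp(-x))
    return (sig - y) * dout

# With x = 0 and label y = 1 the gradient is (0.5 - 1) * dout = -0.5 * dout.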
|
|
1615
|
+
|
|
1616
|
+
|
|
1617
|
+
class SliceGrad(PrimitiveWithInfer):
|
|
1618
|
+
"""Reverse of slice."""
|
|
1619
|
+
|
|
1620
|
+
@prim_attr_register
|
|
1621
|
+
def __init__(self):
|
|
1622
|
+
"""Initialize SliceGrad"""
|
|
1623
|
+
self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])
|
|
1624
|
+
|
|
1625
|
+
def __infer__(self, dy, x, begin, size):
|
|
1626
|
+
dy_shape, x_shape, size_value, begin_v = dy['shape'], x['shape'], size['value'], begin['value']
|
|
1627
|
+
dy_shape_len = len(dy_shape)
|
|
1628
|
+
if size_value is not None and not is_shape_unknown(x_shape) and not is_shape_unknown(dy_shape):
|
|
1629
|
+
size_value = list(size_value)
|
|
1630
|
+
for i in range(dy_shape_len):
|
|
1631
|
+
if size_value[i] == -1:
|
|
1632
|
+
size_value[i] = x_shape[i] - begin_v[i]
|
|
1633
|
+
validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], validator.LE, self.name)
|
|
1634
|
+
validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]',
|
|
1635
|
+
size_value[i], validator.EQ, self.name)
|
|
1636
|
+
|
|
1637
|
+
return {'shape': x_shape,
|
|
1638
|
+
'dtype': x['dtype'],
|
|
1639
|
+
'value': None}
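# Illustrative sketch (not part of the MindSpore source; the function name is
# hypothetical): __infer__ above resolves a -1 entry in `size` to "everything from
# `begin` to the end of that axis" before checking the incoming dy shape against it.
def resolve_slice_size(x_shape, begin, size):
    resolved = list(size)
    for i, s in enumerate(resolved):
        if s == -1:
            resolved[i] = x_shape[i] - begin[i]
    return resolved

# resolve_slice_size([4, 6], begin=[1, 2], size=[-1, 3]) == [3, 3]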
|
|
1640
|
+
|
|
1641
|
+
|
|
1642
|
+
class SmoothL1LossGrad(Primitive):
|
|
1643
|
+
"""Computes gradient for prediction on SmoothL1Loss."""
|
|
1644
|
+
|
|
1645
|
+
@prim_attr_register
|
|
1646
|
+
def __init__(self, beta=1.0, reduction='none'):
|
|
1647
|
+
self.add_prim_attr('sigma', self.beta)
|
|
1648
|
+
self.reduction = validator.check_string(
|
|
1649
|
+
reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
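# Illustrative numpy sketch (not part of the MindSpore source): for
# reduction='none', the gradient of the smooth L1 loss w.r.t. the prediction is
# (prediction - target) / beta inside the quadratic zone and sign(prediction - target)
# outside it, scaled by the incoming gradient.
import numpy as np

def smooth_l1_loss_grad(prediction, target, dout, beta=1.0):
    diff = prediction - target
    grad = np.where(np.abs(diff) < beta, diff / beta, np.sign(diff))
    return grad * dout

# smooth_l1_loss_grad(np.array([0.2, 3.0]), np.array([0.0, 0.0]), np.ones(2)) -> [0.2, 1.0]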
|
|
1650
|
+
|
|
1651
|
+
|
|
1652
|
+
class SoftMarginLossGrad(Primitive):
|
|
1653
|
+
"""Computes gradient for prediction on SoftMarginLoss."""
|
|
1654
|
+
|
|
1655
|
+
@prim_attr_register
|
|
1656
|
+
def __init__(self, reduction="mean"):
|
|
1657
|
+
self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
|
|
1658
|
+
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
|
1659
|
+
|
|
1660
|
+
|
|
1661
|
+
class StridedSliceGrad(Primitive):
|
|
1662
|
+
"""
|
|
1663
|
+
Performs grad of StridedSlice operation.
|
|
1664
|
+
|
|
1665
|
+
Args:
|
|
1666
|
+
begin_mask (int): Start indexing the slice. Default: 0.
|
|
1667
|
+
end_mask (int): End indexing the slice. Default: 0.
|
|
1668
|
+
ellipsis_mask (int): An int32 mask. Default: 0.
|
|
1669
|
+
new_axis_mask (int): An int32 mask. Default: 0.
|
|
1670
|
+
shrink_axis_mask (int): An int32 mask. Default: 0.
|
|
1671
|
+
|
|
1672
|
+
Returns:
|
|
1673
|
+
Tensor, has the same shape as the input.
|
|
1674
|
+
"""
|
|
1675
|
+
|
|
1676
|
+
@prim_attr_register
|
|
1677
|
+
def __init__(self,
|
|
1678
|
+
begin_mask=0,
|
|
1679
|
+
end_mask=0,
|
|
1680
|
+
ellipsis_mask=0,
|
|
1681
|
+
new_axis_mask=0,
|
|
1682
|
+
shrink_axis_mask=0):
|
|
1683
|
+
"""Initialize StridedSliceGrad"""
|
|
1684
|
+
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
|
|
1685
|
+
validator.check_value_type('end_mask', end_mask, [int], self.name)
|
|
1686
|
+
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
|
|
1687
|
+
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
|
|
1688
|
+
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
|
|
1689
|
+
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
|
1690
|
+
|
|
1691
|
+
|
|
1692
|
+
class SoftplusGrad(Primitive):
|
|
1693
|
+
"""Computes gradient for the Softplus activation."""
|
|
1694
|
+
|
|
1695
|
+
@prim_attr_register
|
|
1696
|
+
def __init__(self):
|
|
1697
|
+
self.init_prim_io_names(inputs=['gradients', 'features'], outputs=['backprops'])
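# Illustrative numpy sketch (not part of the MindSpore source): since
# softplus(x) = log(1 + exp(x)) has derivative sigmoid(x), the backprop is the
# incoming gradient multiplied by sigmoid(features).
import numpy as np

def softplus_grad(gradients, features):
    return gradients / (1.0 + np.exp(-features))

# softplus_grad(np.ones(1), np.zeros(1)) == [0.5]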
|
|
1698
|
+
|
|
1699
|
+
|
|
1700
|
+
class TanhGrad(Primitive):
|
|
1701
|
+
"""Computes gradient of hyperbolic tangent of input element-wise."""
|
|
1702
|
+
|
|
1703
|
+
@prim_attr_register
|
|
1704
|
+
def __init__(self):
|
|
1705
|
+
"""Initialize TanhGrad"""
|
|
1706
|
+
self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
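# Illustrative numpy sketch (not part of the MindSpore source): TanhGrad takes the
# forward output y = tanh(x) and the incoming gradient dy, and returns
# dy * (1 - y**2), i.e. the chain rule through tanh expressed in terms of y.
import numpy as np

def tanh_grad(y, dy):
    return dy * (1.0 - y * y)

x = np.array([0.0, 1.0])
y = np.tanh(x)
# tanh_grad(y, np.ones_like(y)) equals 1 / cosh(x)**2 within float tolerance.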
|
|
1707
|
+
|
|
1708
|
+
|
|
1709
|
+
class MirrorPadGrad(Primitive):
|
|
1710
|
+
"""Gradients of MirrorPad operation."""
|
|
1711
|
+
|
|
1712
|
+
@prim_attr_register
|
|
1713
|
+
def __init__(self, mode="REFLECT"):
|
|
1714
|
+
"""Initialize MirrorPad"""
|
|
1715
|
+
self.init_prim_io_names(inputs=['dy', 'paddings'], outputs=['output'])
|
|
1716
|
+
validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
|
|
1717
|
+
self.mode = mode
|
|
1718
|
+
|
|
1719
|
+
|
|
1720
|
+
class PadV3Grad(Primitive):
|
|
1721
|
+
"""Gradients of PadV3 operation."""
|
|
1722
|
+
|
|
1723
|
+
@prim_attr_register
|
|
1724
|
+
def __init__(self, mode='reflect', paddings_contiguous=True):
|
|
1725
|
+
"""Initialize Padv3Grad"""
|
|
1726
|
+
self.add_prim_attr("cust_aicpu", self.name)
|
|
1727
|
+
self.init_prim_io_names(inputs=['x', 'paddings'], outputs=['y'])
|
|
1728
|
+
validator.check_string(mode, ['reflect', 'edge', 'circular'], 'mode', self.name)
|
|
1729
|
+
validator.check_bool(paddings_contiguous, "paddings_contiguous", self.name)
|
|
1730
|
+
self.mode = mode
|
|
1731
|
+
self.paddings_contiguous = paddings_contiguous
|
|
1732
|
+
|
|
1733
|
+
|
|
1734
|
+
class EmbeddingLookupCommGrad(PrimitiveWithInfer):
|
|
1735
|
+
"""
|
|
1736
|
+
Performs the gradient for the communication part of EmbeddingLookup operator.
|
|
1737
|
+
|
|
1738
|
+
This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
|
|
1739
|
+
this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
|
|
1740
|
+
"""
|
|
1741
|
+
|
|
1742
|
+
@prim_attr_register
|
|
1743
|
+
def __init__(self):
|
|
1744
|
+
self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
|
|
1745
|
+
self.set_device('CPU')
|
|
1746
|
+
self.tuple_setitem = Primitive('tuple_setitem')
|
|
1747
|
+
|
|
1748
|
+
def __infer__(self, dy, split_num):
|
|
1749
|
+
"""
|
|
1750
|
+
This primitive is implemented by three steps:
|
|
1751
|
+
1) Splits the 'dy' along dimension 0 into 'split_num' parts.
|
|
1752
|
+
2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
|
|
1753
|
+
3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
|
|
1754
|
+
along dimension 0.
|
|
1755
|
+
|
|
1756
|
+
The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
|
|
1757
|
+
"""
|
|
1758
|
+
dy_shape = tuple(dy['shape'])
|
|
1759
|
+
split_num_value = split_num['value']
|
|
1760
|
+
validator.check_value_type("split_num_value", split_num_value, [int], self.name)
|
|
1761
|
+
dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
|
|
1762
|
+
return {'shape': dy_shape_all,
|
|
1763
|
+
'dtype': dy['dtype'],
|
|
1764
|
+
'value': None}
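# Illustrative sketch (not part of the MindSpore source; the function name is
# hypothetical): as described in the docstring above, after the split /
# _HostAllGather / concat sequence the leading dimension of the output is 8 times
# that of dy (one slice gathered from each of the 8 host processes assumed here).
def comm_grad_output_shape(dy_shape):
    return (dy_shape[0] * 8,) + tuple(dy_shape[1:])

# comm_grad_output_shape((16, 128)) == (128, 128)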
|
|
1765
|
+
|
|
1766
|
+
|
|
1767
|
+
class RefToEmbed(Primitive):
|
|
1768
|
+
r"""
|
|
1769
|
+
Make a key from Ref.
|
|
1770
|
+
|
|
1771
|
+
The key is a symbolic_key, an embedding of a Parameter, which is used as the key of the variable in env_type,
|
|
1772
|
+
and items are retrieved by the operation `EnvironGet` with the symbolic_key instance. The `Parameter` is a ref.
|
|
1773
|
+
|
|
1774
|
+
Inputs:
|
|
1775
|
+
- **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
|
|
1776
|
+
|
|
1777
|
+
Outputs:
|
|
1778
|
+
symbolic_key, made from the Ref.
|
|
1779
|
+
|
|
1780
|
+
Examples:
|
|
1781
|
+
>>> class Net(nn.Cell):
|
|
1782
|
+
>>> def __init__(self):
|
|
1783
|
+
>>> super(Net, self).__init__()
|
|
1784
|
+
>>> self.weight = mindspore.Parameter(1.0, name='weight')
|
|
1785
|
+
>>>
|
|
1786
|
+
>>> def construct(self):
|
|
1787
|
+
>>> key = RefToEmbed()(self.weight)
|
|
1788
|
+
>>> return key, self.weight
|
|
1789
|
+
"""
|
|
1790
|
+
__mindspore_signature__ = (
|
|
1791
|
+
sig.make_sig('variable', sig.sig_rw.RW_REF),
|
|
1792
|
+
)
|
|
1793
|
+
|
|
1794
|
+
@prim_attr_register
|
|
1795
|
+
def __init__(self):
|
|
1796
|
+
pass
|
|
1797
|
+
|
|
1798
|
+
|
|
1799
|
+
class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
|
|
1800
|
+
"""Computes the state gradients of BasicLSTMCell."""
|
|
1801
|
+
|
|
1802
|
+
@prim_attr_register
|
|
1803
|
+
def __init__(self, forget_bias, activation):
|
|
1804
|
+
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
|
1805
|
+
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
|
1806
|
+
|
|
1807
|
+
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
|
|
1808
|
+
# dhy and dcy should be same shape
|
|
1809
|
+
validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
|
|
1810
|
+
validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1811
|
+
validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1812
|
+
validator.check("it rank", len(it_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1813
|
+
validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1814
|
+
validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1815
|
+
validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1816
|
+
validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), validator.EQ, self.name)
|
|
1817
|
+
validator.check("dht shape", dht_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1818
|
+
validator.check("dct shape", dct_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1819
|
+
validator.check("it shape", it_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1820
|
+
validator.check("jt shape", jt_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1821
|
+
validator.check("ft shape", ft_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1822
|
+
validator.check("ot shape", ot_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1823
|
+
validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, validator.EQ, self.name)
|
|
1824
|
+
|
|
1825
|
+
dgate_shape = (c_shape[0], 4 * c_shape[1])
|
|
1826
|
+
dct_1_shape = c_shape
|
|
1827
|
+
|
|
1828
|
+
return (dgate_shape, dct_1_shape)
|
|
1829
|
+
|
|
1830
|
+
def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
|
|
1831
|
+
validator.check_subclass("c", c_dtype, [mstype.tensor_type], self.name)
|
|
1832
|
+
validator.check_subclass("dht", dht_dtype, [mstype.tensor_type], self.name)
|
|
1833
|
+
validator.check_subclass("dct", dct_dtype, [mstype.tensor_type], self.name)
|
|
1834
|
+
validator.check_subclass("it", it_dtype, [mstype.tensor_type], self.name)
|
|
1835
|
+
validator.check_subclass("jt", jt_dtype, [mstype.tensor_type], self.name)
|
|
1836
|
+
validator.check_subclass("ft", ft_dtype, [mstype.tensor_type], self.name)
|
|
1837
|
+
validator.check_subclass("ot", ot_dtype, [mstype.tensor_type], self.name)
|
|
1838
|
+
validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor_type], self.name)
|
|
1839
|
+
validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1840
|
+
validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1841
|
+
validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1842
|
+
validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1843
|
+
validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1844
|
+
validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1845
|
+
validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1846
|
+
validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1847
|
+
return (c_dtype, c_dtype)
|
|
1848
|
+
|
|
1849
|
+
|
|
1850
|
+
class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
|
|
1851
|
+
"""Computes the weight gradients of BasicLSTM."""
|
|
1852
|
+
|
|
1853
|
+
@prim_attr_register
|
|
1854
|
+
def __init__(self):
|
|
1855
|
+
pass
|
|
1856
|
+
|
|
1857
|
+
def infer_shape(self, x_shape, h_shape, dgate_shape):
|
|
1858
|
+
validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
|
|
1859
|
+
validator.check("h rank", len(h_shape), " x rank", len(x_shape), validator.EQ, self.name)
|
|
1860
|
+
validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), validator.EQ, self.name)
|
|
1861
|
+
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], validator.EQ, self.name)
|
|
1862
|
+
validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], validator.EQ, self.name)
|
|
1863
|
+
validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], validator.EQ, self.name)
|
|
1864
|
+
input_size = x_shape[1]
|
|
1865
|
+
hidden_size = h_shape[1]
|
|
1866
|
+
dw_shape = (input_size + hidden_size, 4 * hidden_size)
|
|
1867
|
+
db_shape = (4 * hidden_size,)
|
|
1868
|
+
return (dw_shape, db_shape)
|
|
1869
|
+
|
|
1870
|
+
def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
|
|
1871
|
+
validator.check_subclass("x", x_dtype, mstype.tensor_type, self.name)
|
|
1872
|
+
validator.check_subclass("h", h_dtype, mstype.tensor_type, self.name)
|
|
1873
|
+
validator.check_subclass("dgate", dgate_dtype, mstype.tensor_type, self.name)
|
|
1874
|
+
validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1875
|
+
validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1876
|
+
validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1877
|
+
return (x_dtype, x_dtype)
|
|
1878
|
+
|
|
1879
|
+
|
|
1880
|
+
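# A minimal shape sketch (illustrative only; the sizes and the helper name are
# assumptions) of BasicLSTMCellWeightGrad.infer_shape above: the weight gradient
# stacks the input and hidden contributions along dimension 0 and the four gates
# along dimension 1, and the bias gradient covers the four gates.
def _example_basic_lstm_cell_weight_grad_shapes(batch=8, input_size=16, hidden_size=32):
    x_shape = (batch, input_size)
    h_shape = (batch, hidden_size)
    dgate_shape = (batch, 4 * hidden_size)
    dw_shape = (input_size + hidden_size, 4 * hidden_size)   # (48, 128)
    db_shape = (4 * hidden_size,)                            # (128,)
    return x_shape, h_shape, dgate_shape, dw_shape, db_shape
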
class BasicLSTMCellInputGrad(PrimitiveWithInfer):
|
|
1881
|
+
"""Computes the input gradients of BasicLSTM."""
|
|
1882
|
+
|
|
1883
|
+
@prim_attr_register
|
|
1884
|
+
def __init__(self, keep_prob):
|
|
1885
|
+
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
|
1886
|
+
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, validator.INC_BOTH, "keep_prob", self.name)
|
|
1887
|
+
|
|
1888
|
+
def infer_shape(self, dgate_shape, w_shape):
|
|
1889
|
+
validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
|
|
1890
|
+
validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
|
|
1891
|
+
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], validator.EQ, self.name)
|
|
1892
|
+
batch_size = dgate_shape[0]
|
|
1893
|
+
hidden_size = dgate_shape[1] // 4
|
|
1894
|
+
input_size = w_shape[0] - hidden_size
|
|
1895
|
+
dxt_shape = (batch_size, input_size)
|
|
1896
|
+
dht_shape = (batch_size, hidden_size)
|
|
1897
|
+
return (dxt_shape, dht_shape)
|
|
1898
|
+
|
|
1899
|
+
def infer_dtype(self, dgate_dtype, w_dtype):
|
|
1900
|
+
validator.check_subclass("dgate", dgate_dtype, mstype.tensor_type, self.name)
|
|
1901
|
+
validator.check_subclass("w", w_dtype, mstype.tensor_type, self.name)
|
|
1902
|
+
validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1903
|
+
validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1904
|
+
return (dgate_dtype, dgate_dtype)
|
|
1905
|
+
|
|
1906
|
+
|
|
1907
|
+
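# A minimal sketch (illustrative only; the sizes and the helper name are assumptions)
# of how BasicLSTMCellInputGrad.infer_shape above recovers the sizes it needs from
# `dgate` and `w` alone: the hidden size is a quarter of dgate's second dimension,
# and the input size is what remains of w's first dimension after the hidden part.
def _example_basic_lstm_cell_input_grad_shapes(dgate_shape=(8, 128), w_shape=(48, 128)):
    batch_size = dgate_shape[0]
    hidden_size = dgate_shape[1] // 4        # 32
    input_size = w_shape[0] - hidden_size    # 16
    dxt_shape = (batch_size, input_size)     # (8, 16)
    dht_shape = (batch_size, hidden_size)    # (8, 32)
    return dxt_shape, dht_shape
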
class InvGrad(Primitive):
|
|
1908
|
+
"""Computes gradients for inv operation."""
|
|
1909
|
+
|
|
1910
|
+
@prim_attr_register
|
|
1911
|
+
def __init__(self):
|
|
1912
|
+
self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
|
|
1913
|
+
|
|
1914
|
+
|
|
1915
|
+
class LRNGrad(Primitive):
|
|
1916
|
+
"""Computes gradients for LRN operation."""
|
|
1917
|
+
|
|
1918
|
+
@prim_attr_register
|
|
1919
|
+
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
|
|
1920
|
+
self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
|
|
1921
|
+
validator.check_value_type("depth_radius", depth_radius, [int], self.name)
|
|
1922
|
+
validator.check_value_type("bias", bias, [float], self.name)
|
|
1923
|
+
validator.check_value_type("alpha", alpha, [float], self.name)
|
|
1924
|
+
validator.check_value_type("beta", beta, [float], self.name)
|
|
1925
|
+
|
|
1926
|
+
|
|
1927
|
+
class MvlgammaGrad(Primitive):
|
|
1928
|
+
r"""
|
|
1929
|
+
Computes gradients for Mvlgamma.
|
|
1930
|
+
|
|
1931
|
+
The following tex shows the mathematical calculation process of Mvlgamma:
|
|
1932
|
+
|
|
1933
|
+
.. math::
|
|
1934
|
+
|
|
1935
|
+
\log (\Gamma_{p}(a))=C+\sum_{i=1}^{p} \log (\Gamma(a-\frac{i-1}{2}))
|
|
1936
|
+
|
|
1937
|
+
where :math:`C = \log(\pi) \times \frac{p(p-1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
|
|
1938
|
+
|
|
1939
|
+
Args:
|
|
1940
|
+
p(int): The number of dimensions. And the value of `p` must be greater than or equal to 1.
|
|
1941
|
+
|
|
1942
|
+
Inputs:
|
|
1943
|
+
- **y_grad** (Tensor) - The input gradient.
|
|
1944
|
+
- **x** (Tensor) - The input of Mvlgamma with data type of float32 or float64.
|
|
1945
|
+
|
|
1946
|
+
Outputs:
|
|
1947
|
+
Tensor, has the same shape and type as `x`.
|
|
1948
|
+
|
|
1949
|
+
Raises:
|
|
1950
|
+
TypeError: If dtype of `y_grad` or `x` is neither float32 nor float64.
|
|
1951
|
+
TypeError: If `p` is not an int.
|
|
1952
|
+
ValueError: If p is not greater than or equal to 1.
|
|
1953
|
+
ValueError: If all elements of `x` are not greater than (p-1)/2.
|
|
1954
|
+
|
|
1955
|
+
Supported Platforms:
|
|
1956
|
+
``Ascend`` ``CPU``
|
|
1957
|
+
"""
|
|
1958
|
+
|
|
1959
|
+
@prim_attr_register
|
|
1960
|
+
def __init__(self, p):
|
|
1961
|
+
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['x_grad'])
|
|
1962
|
+
self.p = validator.check_value_type('p', p, [int], self.name)
|
|
1963
|
+
|
|
1964
|
+
|
|
1965
|
+
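# A pure-Python reference sketch (illustrative only, not used by the operator; the
# helper name is an assumption) of the multivariate log-gamma formula quoted in the
# MvlgammaGrad docstring above:
#   log Gamma_p(a) = log(pi) * p*(p-1)/4 + sum_{i=1..p} log Gamma(a - (i-1)/2),
# valid when every a is greater than (p-1)/2.
def _example_mvlgamma_reference(a, p):
    import math
    c = math.log(math.pi) * p * (p - 1) / 4.0
    return c + sum(math.lgamma(a - (i - 1) / 2.0) for i in range(1, p + 1))
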
class CdistGrad(Primitive):
|
|
1966
|
+
"""Computes gradient for Cdist."""
|
|
1967
|
+
|
|
1968
|
+
@prim_attr_register
|
|
1969
|
+
def __init__(self, p=2.0):
|
|
1970
|
+
validator.check_value_type("p", p, [float], self.name)
|
|
1971
|
+
self.init_prim_io_names(inputs=['grad', 'input_x', 'input_y', 'cdist'], outputs=['output'])
|
|
1972
|
+
|
|
1973
|
+
|
|
1974
|
+
class PdistGrad(Primitive):
|
|
1975
|
+
"""Computes gradient for Pdist operation.
|
|
1976
|
+
|
|
1977
|
+
Args:
|
|
1978
|
+
p (float): the p value for the Pdist formulation. Default: 2.0.
|
|
1979
|
+
|
|
1980
|
+
Inputs:
|
|
1981
|
+
- **y_grad** (Tensor) - The gradients of loss to output of Pdist function.
|
|
1982
|
+
- **x** (Tensor) - Input tensor of shape :math:`(N, M)`.
|
|
1983
|
+
Must be the input `x` of the forward operator Pdist.
|
|
1984
|
+
- **y** (Tensor) - Input tensor of shape :math:`(N*(N-1)/2)`.
|
|
1985
|
+
Must be the output `y` of the forward operator Pdist.
|
|
1986
|
+
|
|
1987
|
+
Outputs:
|
|
1988
|
+
Tensor, with the same shape and dtype as `x`.
|
|
1989
|
+
|
|
1990
|
+
Raises:
|
|
1991
|
+
TypeError: If one of `y_grad`, `x` and `y` is not a Tensor.
|
|
1992
|
+
TypeError: If dtype of `y_grad`, `x` and `y` are not all float16, float32 or float64.
|
|
1993
|
+
TypeError: If `p` is not a float.
|
|
1994
|
+
ValueError: If `p` is a negative float.
|
|
1995
|
+
ValueError: If shape of `y_grad` is not same as `y`.
|
|
1996
|
+
ValueError: If dimension of `x` is not 2.
|
|
1997
|
+
|
|
1998
|
+
Supported Platforms:
|
|
1999
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2000
|
+
"""
|
|
2001
|
+
|
|
2002
|
+
@prim_attr_register
|
|
2003
|
+
def __init__(self, p=2.0):
|
|
2004
|
+
validator.check_value_type("p", p, [float], self.name)
|
|
2005
|
+
if p < 0:
|
|
2006
|
+
raise ValueError(f'Pdist p must be a non-negative value, but got `{p}`.')
|
|
2007
|
+
self.init_prim_io_names(inputs=['y_grad', 'x', 'y'], outputs=['x_grad'])
|
|
2008
|
+
|
|
2009
|
+
|
|
2010
|
+
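# A shape sketch (illustrative only; N, M and the helper name are assumptions) for
# PdistGrad above: the forward Pdist output `y` (and therefore `y_grad`) holds one
# distance per unordered pair of rows of `x`, i.e. N*(N-1)/2 values.
def _example_pdist_grad_shapes(n=5, m=3):
    x_shape = (n, m)
    y_shape = (n * (n - 1) // 2,)   # (10,)
    x_grad_shape = x_shape          # output has the same shape and dtype as x
    return x_shape, y_shape, x_grad_shape
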
class MultilabelMarginLossGrad(Primitive):
|
|
2011
|
+
"""
|
|
2012
|
+
Compute the gradients of MultilabelMarginLoss operation.
|
|
2013
|
+
|
|
2014
|
+
Args:
|
|
2015
|
+
reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
|
|
2016
|
+
``'sum'`` . Default: ``'mean'`` .
|
|
2017
|
+
|
|
2018
|
+
- ``'none'``: no reduction will be applied.
|
|
2019
|
+
- ``'mean'``: compute and return the mean of elements in the output.
|
|
2020
|
+
- ``'sum'``: the output elements will be summed.
|
|
2021
|
+
|
|
2022
|
+
Inputs:
|
|
2023
|
+
- **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
|
|
2024
|
+
the same shape and data type as forward output `y`.
|
|
2025
|
+
- **x** (Tensor) - Predict data. Tensor of shape :math:`(C)` or :math:`(N, C)`, where :math:`N`
|
|
2026
|
+
is the batch size and :math:`C` is the number of classes. Data type must be float16 or float32.
|
|
2027
|
+
- **target** (Tensor) - Ground truth data, with the same shape as `x`, data type must be int32 and
|
|
2028
|
+
label targets padded by -1.
|
|
2029
|
+
- **is_target** (Tensor) - Forward output tensor for backward input, with the same shape and
|
|
2030
|
+
data type as `target`.
|
|
2031
|
+
|
|
2032
|
+
Outputs:
|
|
2033
|
+
The shape of output :math:`(C)` or :math:`(N, C)`, with the same shape and data type as `x`.
|
|
2034
|
+
|
|
2035
|
+
Raises:
|
|
2036
|
+
TypeError: If `x` or `target` or `y_grad` is not a Tensor.
|
|
2037
|
+
TypeError: If dtype of `x` is neither float16 nor float32.
|
|
2038
|
+
TypeError: If dtype of `target` is not int32.
|
|
2039
|
+
TypeError: If dtype of `y_grad` is not the same as `x`.
|
|
2040
|
+
ValueError: If length of shape of `x` is neither 1 nor 2.
|
|
2041
|
+
ValueError: If shape of `x` is not the same as `target`.
|
|
2042
|
+
ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
|
|
2043
|
+
ValueError: If shape of `y_grad` is not the same as forward output `y`.
|
|
2044
|
+
|
|
2045
|
+
Supported Platforms:
|
|
2046
|
+
``Ascend``
|
|
2047
|
+
"""
|
|
2048
|
+
|
|
2049
|
+
@prim_attr_register
|
|
2050
|
+
def __init__(self, reduction="mean"):
|
|
2051
|
+
"""Initialize MultilabelMarginLossGrad"""
|
|
2052
|
+
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
|
2053
|
+
self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'is_target'], outputs=['x_grad'])
|
|
2054
|
+
|
|
2055
|
+
|
|
2056
|
+
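# A shape/dtype sketch (illustrative only; N, C and the helper name are assumptions)
# for MultilabelMarginLossGrad above, restating the input contract from its docstring.
def _example_multilabel_margin_loss_grad_shapes(n=2, c=4):
    x_shape = (n, c)                 # float16/float32 predictions
    target_shape = x_shape           # int32 labels, padded with -1
    is_target_shape = target_shape   # same shape and dtype as `target`
    x_grad_shape = x_shape           # output matches `x`
    return x_shape, target_shape, is_target_shape, x_grad_shape
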
class Dilation2DBackpropInput(Primitive):
|
|
2057
|
+
"""
|
|
2058
|
+
Computes the gradient of morphological 2-D dilation with respect to the input.
|
|
2059
|
+
|
|
2060
|
+
.. warning::
|
|
2061
|
+
This operator is an experimental operator, which has some accuracy problems for some inputs.
|
|
2062
|
+
|
|
2063
|
+
Args:
|
|
2064
|
+
stride (Union[int, tuple[int]]): The distance of filter moving, an int number that represents
|
|
2065
|
+
the height and width of movement are both strides, a tuple of two int numbers that
|
|
2066
|
+
represent height and width of movement respectively, or a tuple of four int numbers which
|
|
2067
|
+
should be :math:`(1, 1, H_{stride}, W_{stride})`.
|
|
2068
|
+
dilation (Union[int, tuple[int]]): The input stride for atrous morphological dilation. The data
|
|
2069
|
+
type is int or a tuple of 2 or 4 integers. Its value must be greater or equal to 1 and bounded
|
|
2070
|
+
by the height and width of the input `x`.
|
|
2071
|
+
pad_mode (str): Specifies padding mode. The optional values are "same", "valid".
|
|
2072
|
+
Default: "same". Both upper and lower case are supported.
|
|
2073
|
+
data_format (str): The format of input and output data. Only NCHW format is supported at present.
|
|
2074
|
+
Default: 'NCHW'.
|
|
2075
|
+
|
|
2076
|
+
Inputs:
|
|
2077
|
+
- **x** (Tensor) - Input data. A four dimension tensor with float16 or float32 data type. The shape must be
|
|
2078
|
+
:math:`(N, C_{in}, H_{in}, W_{in})`.
|
|
2079
|
+
- **filter** (Tensor) - A three dimension tensor with the same type as input. The shape must be
|
|
2080
|
+
:math:`(C_{in}, H_{filter}, W_{filter})`.
|
|
2081
|
+
- **out_backprop** (Tensor) - The gradients with respect to the output of the convolution.
|
|
2082
|
+
A four dimension tensor with float16 or float32 data type. The shape must be
|
|
2083
|
+
:math:`(N, C_{in}, H_{out}, W_{out})`.
|
|
2084
|
+
|
|
2085
|
+
Outputs:
|
|
2086
|
+
Tensor, the gradients with respect to the input of convolution. It has the same shape and type as the input `x`.
|
|
2087
|
+
|
|
2088
|
+
Raises:
|
|
2089
|
+
TypeError: If type of `x` or `filter` is not the type in [uint8, uint16, uint32, uint64, int8, int16,
|
|
2090
|
+
int32, int64, float16, float32, float64].
|
|
2091
|
+
TypeError: If type of `out_backprop` is not the type in [uint8, uint16, uint32, uint64, int8, int16,
|
|
2092
|
+
int32, int64, float16, float32, float64].
|
|
2093
|
+
TypeError: If `stride` or `dilation` is not an int number or a tuple of two or four int numbers.
|
|
2094
|
+
ValueError: If the length of `stride` or `dilation` is neither two nor four when they are tuples.
|
|
2095
|
+
ValueError: If `stride` or `dilation` is not (1, 1, height, width) when it is a tuple of four int numbers.
|
|
2096
|
+
ValueError: If `stride` is not in the range of [1, 255].
|
|
2097
|
+
ValueError: If `dilation` is less than 1.
|
|
2098
|
+
ValueError: If `pad_mode` is not a str of 'same', 'valid', 'SAME' or 'VALID'.
|
|
2099
|
+
ValueError: If `data_format` is not the str of 'NCHW'.
|
|
2100
|
+
|
|
2101
|
+
Supported Platforms:
|
|
2102
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2103
|
+
|
|
2104
|
+
Examples:
|
|
2105
|
+
(pad_mode="SAME", data_format="NCHW")
|
|
2106
|
+
>>> out_backprop = Tensor(np.ones([1, 3, 4, 4]), mstype.float32)
|
|
2107
|
+
>>> filter = Tensor(np.ones([3 , 2 , 2]), mstype.float32)
|
|
2108
|
+
>>> x = Tensor(np.ones([1, 3, 4, 4]), mstype.float32)
|
|
2109
|
+
>>> dilation_backprop_input = G.Dilation2DBackpropInput(stride=1, dilation=1)
|
|
2110
|
+
>>> output = dilation_backprop_input(x, filter, out_backprop)
|
|
2111
|
+
>>> print(output)
|
|
2112
|
+
[[[[1. 1. 1. 1.]
|
|
2113
|
+
[1. 1. 1. 1.]
|
|
2114
|
+
[1. 1. 1. 1.]
|
|
2115
|
+
[1. 1. 1. 1.]]
|
|
2116
|
+
[[1. 1. 1. 1.]
|
|
2117
|
+
[1. 1. 1. 1.]
|
|
2118
|
+
[1. 1. 1. 1.]
|
|
2119
|
+
[1. 1. 1. 1.]]
|
|
2120
|
+
[[1. 1. 1. 1.]
|
|
2121
|
+
[1. 1. 1. 1.]
|
|
2122
|
+
[1. 1. 1. 1.]
|
|
2123
|
+
[1. 1. 1. 1.]]]]
|
|
2124
|
+
"""
|
|
2125
|
+
|
|
2126
|
+
@prim_attr_register
|
|
2127
|
+
def __init__(self, stride, dilation, pad_mode="SAME", data_format="NCHW"):
|
|
2128
|
+
"""Initialize Dilation2DBackpropInput"""
|
|
2129
|
+
|
|
2130
|
+
def _check_format_stride_or_dilation(arg_name, arg_value, prim_name, data_format):
|
|
2131
|
+
validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
|
|
2132
|
+
if isinstance(arg_value, int):
|
|
2133
|
+
ret_value = (1, arg_value, arg_value, 1) if data_format == "NHWC" else (1, 1, arg_value, arg_value)
|
|
2134
|
+
elif len(arg_value) == 2:
|
|
2135
|
+
ret_value = (1, arg_value[0], arg_value[1], 1) if data_format == "NHWC" else \
|
|
2136
|
+
(1, 1, arg_value[0], arg_value[1])
|
|
2137
|
+
elif len(arg_value) == 4:
|
|
2138
|
+
if data_format == "NHWC" and (arg_value[0] != 1 or arg_value[3] != 1):
|
|
2139
|
+
raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be "
|
|
2140
|
+
f"[1, {arg_name}_height, {arg_name}_weigth, 1] when data_format is 'NHWC', "
|
|
2141
|
+
f"but got {arg_value}")
|
|
2142
|
+
if data_format == "NCHW" and (arg_value[0] != 1 or arg_value[1] != 1):
|
|
2143
|
+
raise ValueError(
|
|
2144
|
+
f"For '{prim_name}' attr '{arg_name}' should be [1, 1, {arg_name}_height, {arg_name}_weigth]"
|
|
2145
|
+
f"when data_format is 'NCHW', but got {arg_value}")
|
|
2146
|
+
ret_value = arg_value
|
|
2147
|
+
else:
|
|
2148
|
+
raise ValueError(
|
|
2149
|
+
f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
|
|
2150
|
+
f"or four positive int numbers, but got {arg_value}")
|
|
2151
|
+
for item in ret_value:
|
|
2152
|
+
if isinstance(item, int) and not isinstance(item, bool) and item > 0:
|
|
2153
|
+
continue
|
|
2154
|
+
raise ValueError(
|
|
2155
|
+
f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
|
|
2156
|
+
f"or four positive int numbers, but got {arg_value}")
|
|
2157
|
+
return ret_value
|
|
2158
|
+
|
|
2159
|
+
if data_format == 'NHWC':
|
|
2160
|
+
raise ValueError(f"For '{self.name}', NHWC format is not supported at present.")
|
|
2161
|
+
self.data_format = validator.check_string(self.data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
|
|
2162
|
+
self.add_prim_attr("data_format", self.data_format)
|
|
2163
|
+
self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
|
|
2164
|
+
self.add_prim_attr("pad_mode", self.pad_mode.upper())
|
|
2165
|
+
self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
|
|
2166
|
+
self.add_prim_attr("stride", self.stride)
|
|
2167
|
+
self.dilation = _check_format_stride_or_dilation("dilation", dilation, self.name, self.data_format)
|
|
2168
|
+
self.add_prim_attr("dilation", self.dilation)
|
|
2169
|
+
|
|
2170
|
+
|
|
2171
|
+
class Dilation2DBackpropFilter(Primitive):
|
|
2172
|
+
"""
|
|
2173
|
+
Computes the gradient of morphological 2-D dilation with respect to the filter.
|
|
2174
|
+
|
|
2175
|
+
.. warning::
|
|
2176
|
+
This operator is an experimental operator, which has some accuracy problems for some inputs.
|
|
2177
|
+
|
|
2178
|
+
Args:
|
|
2179
|
+
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
|
|
2180
|
+
the height and width of movement are both strides, a tuple of two int numbers that
|
|
2181
|
+
represent height and width of movement respectively, or a tuple of four int numbers which
|
|
2182
|
+
should be :math:`(1, 1, H_{stride}, W_{stride})`.
|
|
2183
|
+
dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers or a tuple of 4 integers.
|
|
2184
|
+
Specifies the dilation rate to use for dilated convolution.
|
|
2185
|
+
If set to be :math:`k > 1`, there will be :math:`k - 1` pixels skipped for each sampling location.
|
|
2186
|
+
Its value must be greater or equal to 1 and bounded by the height and width of the input `x`.
|
|
2187
|
+
pad_mode (str): Specifies padding mode. The optional values are "same", "valid".
|
|
2188
|
+
Default: "same". Both upper and lower case are supported.
|
|
2189
|
+
data_format (str): The format of input and output data. Only NCHW format is supported at present.
|
|
2190
|
+
Default: 'NCHW'.
|
|
2191
|
+
|
|
2192
|
+
Inputs:
|
|
2193
|
+
- **x** (Tensor) - Input data. A four dimension tensor with float16 or float32 data type. The shape must be
|
|
2194
|
+
:math:`(N, C_{in}, H_{in}, W_{in})`.
|
|
2195
|
+
- **filter** (Tensor) - A three dimension tensor with the same type as input. The shape must be
|
|
2196
|
+
:math:`(C_{in}, H_{filter}, W_{filter})`.
|
|
2197
|
+
- **out_backprop** (Tensor) - The gradients with respect to the output of the convolution.
|
|
2198
|
+
A four dimension tensor with float16 or float32 data type. The shape must be
|
|
2199
|
+
:math:`(N, C_{in}, H_{out}, W_{out})`.
|
|
2200
|
+
|
|
2201
|
+
Outputs:
|
|
2202
|
+
Tensor, the gradients with respect to the `filter` input, with the same data type as `filter`.
|
|
2203
|
+
|
|
2204
|
+
Raises:
|
|
2205
|
+
TypeError: If type of `x` or `filter` is not the type in [uint8, uint16, uint32, uint64, int8, int16,
|
|
2206
|
+
int32, int64, float16, float32, float64].
|
|
2207
|
+
TypeError: If type of `out_backprop` is not the type in [uint8, uint16, uint32, uint64, int8, int16,
|
|
2208
|
+
int32, int64, float16, float32, float64].
|
|
2209
|
+
TypeError: If `stride` or `dilation` is not an int number or a tuple of two or four int numbers.
|
|
2210
|
+
ValueError: If the length of `stride` or `dilation` is neither two nor four when they are tuples.
|
|
2211
|
+
ValueError: If `stride` or `dilation` is not (1, 1, height, width) when it is a tuple of four int numbers.
|
|
2212
|
+
ValueError: If `stride` is not in the range of [1, 255].
|
|
2213
|
+
ValueError: If `dilation` is less than 1.
|
|
2214
|
+
ValueError: If `pad_mode` is not a str of 'same', 'valid', 'SAME' or 'VALID'.
|
|
2215
|
+
ValueError: If `data_format` is not the str of 'NCHW'.
|
|
2216
|
+
|
|
2217
|
+
|
|
2218
|
+
Supported Platforms:
|
|
2219
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2220
|
+
|
|
2221
|
+
Examples:
|
|
2222
|
+
(pad_mode="SAME", data_format="NCHW")
|
|
2223
|
+
>>> x = Tensor(np.ones([2, 3, 4, 4]), mstype.float32)
|
|
2224
|
+
>>> filter = Tensor(np.ones([3,2,2]), mstype.float32)
|
|
2225
|
+
>>> out_backprop = Tensor(np.ones([2,3,2,2]), mstype.float32)
|
|
2226
|
+
>>> dilation_backprop_filter = G.Dilation2DBackpropFilter(stride=2, dilation=1)
|
|
2227
|
+
>>> output = dilation_backprop_filter(x, filter, out_backprop)
|
|
2228
|
+
>>> print(output)
|
|
2229
|
+
[[[8. 8. 8.]
|
|
2230
|
+
[0. 0. 0.]]
|
|
2231
|
+
[[0. 0. 0.]
|
|
2232
|
+
[0. 0. 0.]]]
|
|
2233
|
+
"""
|
|
2234
|
+
|
|
2235
|
+
@prim_attr_register
|
|
2236
|
+
def __init__(self, stride, dilation, pad_mode="SAME", data_format="NCHW"):
|
|
2237
|
+
"""Initialize Dilation2DBackpropFilter"""
|
|
2238
|
+
|
|
2239
|
+
def _check_format_stride_or_dilation(arg_name, arg_value, prim_name, data_format):
|
|
2240
|
+
validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
|
|
2241
|
+
if isinstance(arg_value, int):
|
|
2242
|
+
ret_value = (1, arg_value, arg_value, 1) if data_format == "NHWC" else (1, 1, arg_value, arg_value)
|
|
2243
|
+
elif len(arg_value) == 2:
|
|
2244
|
+
ret_value = (1, arg_value[0], arg_value[1], 1) if data_format == "NHWC" else \
|
|
2245
|
+
(1, 1, arg_value[0], arg_value[1])
|
|
2246
|
+
elif len(arg_value) == 4:
|
|
2247
|
+
if data_format == "NHWC" and (arg_value[0] != 1 or arg_value[3] != 1):
|
|
2248
|
+
raise ValueError(
|
|
2249
|
+
f"For '{prim_name}' attr '{arg_name}' should be [1, {arg_name}_height, {arg_name}_weigth, 1]"
|
|
2250
|
+
f"when data_format is 'NHWC', but got {arg_value}")
|
|
2251
|
+
if data_format == "NCHW" and (arg_value[0] != 1 or arg_value[1] != 1):
|
|
2252
|
+
raise ValueError(
|
|
2253
|
+
f"For '{prim_name}' attr '{arg_name}' should be [1, 1, {arg_name}_height, {arg_name}_weigth]"
|
|
2254
|
+
f"when data_format is 'NCHW', but got {arg_value}")
|
|
2255
|
+
ret_value = arg_value
|
|
2256
|
+
else:
|
|
2257
|
+
raise ValueError(
|
|
2258
|
+
f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
|
|
2259
|
+
f"or four positive int numbers, but got {arg_value}")
|
|
2260
|
+
for item in ret_value:
|
|
2261
|
+
if isinstance(item, int) and not isinstance(item, bool) and item > 0:
|
|
2262
|
+
continue
|
|
2263
|
+
raise ValueError(
|
|
2264
|
+
f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
|
|
2265
|
+
f"or four positive int numbers, but got {arg_value}")
|
|
2266
|
+
return ret_value
|
|
2267
|
+
|
|
2268
|
+
if data_format == 'NHWC':
|
|
2269
|
+
raise ValueError(f"For '{self.name}', NHWC format is not supported at present.")
|
|
2270
|
+
self.data_format = validator.check_string(self.data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
|
|
2271
|
+
self.add_prim_attr("data_format", self.data_format)
|
|
2272
|
+
self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
|
|
2273
|
+
self.add_prim_attr("pad_mode", self.pad_mode.upper())
|
|
2274
|
+
self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
|
|
2275
|
+
def is_in_range(x):
|
|
2276
|
+
return 1 <= x <= 255
|
|
2277
|
+
if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
|
|
2278
|
+
raise ValueError(f"For '{self.name}', size of stride is not supported, "
|
|
2279
|
+
f'stride should be in the range of [1, 255], '
|
|
2280
|
+
f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
|
|
2281
|
+
self.add_prim_attr("stride", self.stride)
|
|
2282
|
+
self.dilation = _check_format_stride_or_dilation("dilation", dilation, self.name, self.data_format)
|
|
2283
|
+
self.add_prim_attr("dilation", self.dilation)
|
|
2284
|
+
|
|
2285
|
+
|
|
2286
|
+
class ParallelResizeBilinearGrad(PrimitiveWithInfer):
|
|
2287
|
+
"""ParallelResizeBilinearGrad ops"""
|
|
2288
|
+
|
|
2289
|
+
@prim_attr_register
|
|
2290
|
+
def __init__(self, ori_image_size, src_start_w, dst_start_w, align_corners):
|
|
2291
|
+
"""Initialize ParallelResizeBilinearGrad."""
|
|
2292
|
+
self.init_prim_io_names(inputs=["grad", "x", "size"], outputs=['y'])
|
|
2293
|
+
validator.check_value_type("ori_image_size", ori_image_size, [tuple, list], self.name)
|
|
2294
|
+
validator.check_value_type("src_start_w", src_start_w, [int], self.name)
|
|
2295
|
+
validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
|
|
2296
|
+
validator.check_value_type("align_corners", align_corners, [bool], self.name)
|
|
2297
|
+
self.ori_image_size = list(ori_image_size)
|
|
2298
|
+
self.src_start_w = src_start_w
|
|
2299
|
+
self.dst_start_w = dst_start_w
|
|
2300
|
+
self.align_corners = align_corners
|
|
2301
|
+
self.half_pixel_centers = False
|
|
2302
|
+
self.add_prim_attr('ori_image_size', self.ori_image_size)
|
|
2303
|
+
self.add_prim_attr('src_start_w', self.src_start_w)
|
|
2304
|
+
self.add_prim_attr('dst_start_w', self.dst_start_w)
|
|
2305
|
+
self.add_prim_attr('align_corners', self.align_corners)
|
|
2306
|
+
self.add_prim_attr('half_pixel_centers', self.half_pixel_centers)
|
|
2307
|
+
|
|
2308
|
+
def __infer__(self, grad, x, size):
|
|
2309
|
+
size_val = size['value']
|
|
2310
|
+
grad_shape = grad['shape']
|
|
2311
|
+
grad_dtype = grad['dtype']
|
|
2312
|
+
x_shape = x['shape']
|
|
2313
|
+
x_dtype = x['dtype']
|
|
2314
|
+
validator.check_tensor_dtype_valid("grad_dtype", grad_dtype, [mstype.float16, mstype.float32], self.name)
|
|
2315
|
+
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
|
|
2316
|
+
if size_val is None:
|
|
2317
|
+
raise ValueError("size must be const input")
|
|
2318
|
+
output_shape = [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]
|
|
2319
|
+
|
|
2320
|
+
return {'shape': output_shape,
|
|
2321
|
+
'dtype': x_dtype,
|
|
2322
|
+
'value': None}
|
|
2323
|
+
|
|
2324
|
+
|
|
2325
|
+
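# A minimal sketch (illustrative only; the shapes and the helper name are assumptions)
# of the output shape rule in ParallelResizeBilinearGrad.__infer__ above: batch and
# channel come from `grad`, while the spatial dimensions come from `x`.
def _example_parallel_resize_bilinear_grad_shape(grad_shape=(2, 3, 8, 8), x_shape=(2, 3, 4, 4)):
    return [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]   # [2, 3, 4, 4]
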
class MultiMarginLossGrad(Primitive):
|
|
2326
|
+
"""
|
|
2327
|
+
Compute the gradients of MultiMarginLoss operation
|
|
2328
|
+
|
|
2329
|
+
Args:
|
|
2330
|
+
p (int): Optional. The norm degree for pairwise distance. Should be 1 or 2. Default: 1.
|
|
2331
|
+
margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
|
|
2332
|
+
reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
|
|
2333
|
+
``'sum'`` . Default: ``'mean'`` .
|
|
2334
|
+
|
|
2335
|
+
- ``'none'``: no reduction will be applied.
|
|
2336
|
+
- ``'mean'``: compute and return the weighted mean of elements in the output.
|
|
2337
|
+
- ``'sum'``: the output elements will be summed.
|
|
2338
|
+
|
|
2339
|
+
Inputs:
|
|
2340
|
+
- **y_grad** (Tensor) - If it is not a scalar, the shape of `y_grad` is :math:`(N, C)`.
|
|
2341
|
+
Data type only supports float16, float32 or float64.
|
|
2342
|
+
- **x** (Tensor) - Input x, with shape :math:`(N, C)`. Data type only support float32, float16 or float64.
|
|
2343
|
+
- **target** (Tensor) - Ground truth labels, with shape :math:`(N,)`. Data type only support int64. The
|
|
2344
|
+
value of target should be non-negative, less than C.
|
|
2345
|
+
- **weight** (Tensor, optional) - The rescaling weight to each class with shape :math:`(C,)`. Data type only
|
|
2346
|
+
support float32, float16 or float64. Default: ``None``.
|
|
2347
|
+
|
|
2348
|
+
Outputs:
|
|
2349
|
+
The shape of the output is :math:`(N, C)`. Data type only supports float16, float32 or float64.
|
|
2350
|
+
Has the same data type as `x`.
|
|
2351
|
+
|
|
2352
|
+
Raises:
|
|
2353
|
+
TypeError: If dtype of `p` and `target` is not int.
|
|
2354
|
+
TypeError: If dtype of `margin` is not float.
|
|
2355
|
+
TypeError: If dtype of `reduction` is not str.
|
|
2356
|
+
TypeError: If dtype of `x` is not float16, float32 or float64.
|
|
2357
|
+
TypeError: If dtype of `weight` and `x` is not the same.
|
|
2358
|
+
ValueError: If 'p' is not 1 or 2.
|
|
2359
|
+
ValueError: If 'reduction' is not one of {'none','sum','mean'}.
|
|
2360
|
+
ValueError: If shape[0] of `x` is not equal to shape[0] of `target`.
|
|
2361
|
+
ValueError: If shape[1] of `x` is not equal to shape[0] of `weight`.
|
|
2362
|
+
ValueError: If rank of `weight` is not 1.
|
|
2363
|
+
ValueError: If rank of `x` is not 2 or rank of 'target' is not 1.
|
|
2364
|
+
|
|
2365
|
+
Supported Platforms:
|
|
2366
|
+
``Ascend`` ``CPU``
|
|
2367
|
+
"""
|
|
2368
|
+
__mindspore_signature__ = (
|
|
2369
|
+
sig.make_sig('y_grad'),
|
|
2370
|
+
sig.make_sig('x'),
|
|
2371
|
+
sig.make_sig('target'),
|
|
2372
|
+
sig.make_sig('weight', default=None)
|
|
2373
|
+
)
|
|
2374
|
+
|
|
2375
|
+
@prim_attr_register
|
|
2376
|
+
def __init__(self, p=1, margin=1.0, reduction="mean"):
|
|
2377
|
+
"""Initialize MultiMarginLossGrad"""
|
|
2378
|
+
self.p = validator.check_value_type('p', p, [int], self.name)
|
|
2379
|
+
validator.check_int(p, {1, 2}, validator.IN, 'p', self.name)
|
|
2380
|
+
self.margin = validator.check_value_type('margin', margin, [float], self.name)
|
|
2381
|
+
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
|
2382
|
+
self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'weight'], outputs=['x_grad'])
|
|
2383
|
+
|
|
2384
|
+
def __call__(self, y_grad, x, target, weight=None):
|
|
2385
|
+
return super().__call__(y_grad, x, target, weight)
|
|
2386
|
+
|
|
2387
|
+
|
|
2388
|
+
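# A shape sketch (illustrative only; N, C and the helper name are assumptions) for
# MultiMarginLossGrad above, restating the documented input/output contract.
def _example_multi_margin_loss_grad_shapes(n=4, c=3):
    x_shape = (n, c)         # predictions
    target_shape = (n,)      # int64 class indices, each in [0, C)
    weight_shape = (c,)      # optional per-class rescaling weight
    x_grad_shape = (n, c)    # output has the same data type as `x`
    return x_shape, target_shape, weight_shape, x_grad_shape
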
class SparseSegmentMeanGrad(Primitive):
|
|
2389
|
+
"""
|
|
2390
|
+
Computes gradients for the SparseSegmentMean operation.
|
|
2391
|
+
|
|
2392
|
+
Inputs:
|
|
2393
|
+
- **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanGrad.
|
|
2394
|
+
- **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following
|
|
2395
|
+
types: int32, int64. Has same rank as `segment_ids`. The shape should be :math:`(N,)`.
|
|
2396
|
+
- **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one of the
|
|
2397
|
+
following types: int32, int64. Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
|
|
2398
|
+
- **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentMean op.
|
|
2399
|
+
|
|
2400
|
+
Outputs:
|
|
2401
|
+
A Tensor. Has the same type as `x` .
|
|
2402
|
+
Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.
|
|
2403
|
+
|
|
2404
|
+
Raises:
|
|
2405
|
+
TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
|
|
2406
|
+
TypeError: If the dtype of `x` is not any of the following data types: {float32, float64}.
|
|
2407
|
+
TypeError: If the dtype of `indices` is not int32.
|
|
2408
|
+
TypeError: If the dtype of `segment_ids` is not int32.
|
|
2409
|
+
TypeError: If the dtype of `output_dim0` is not int32.
|
|
2410
|
+
ValueError: If dimension size of `x` is less than 1.
|
|
2411
|
+
ValueError: If rank of `indices` or `segment_ids` is not 1.
|
|
2412
|
+
ValueError: If dimension size of `output_dim0` is not 0.
|
|
2413
|
+
ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`.
|
|
2414
|
+
ValueError: If `segment_ids` is not sorted.
|
|
2415
|
+
ValueError: If `indices` is out of range of `output_dim0`.
|
|
2416
|
+
|
|
2417
|
+
Supported Platforms:
|
|
2418
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2419
|
+
"""
|
|
2420
|
+
|
|
2421
|
+
@prim_attr_register
|
|
2422
|
+
def __init__(self):
|
|
2423
|
+
"""Initialize SparseSegmentMeanGrad"""
|
|
2424
|
+
self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
|
|
2425
|
+
|
|
2426
|
+
|
|
2427
|
+
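# A shape sketch (illustrative only; the sizes and the helper name are assumptions)
# of the output rule documented for SparseSegmentMeanGrad above: the result keeps the
# trailing dimensions of `x` and replaces dimension 0 with the value of `output_dim0`.
def _example_sparse_segment_mean_grad_shape(x_shape=(6, 4, 5), output_dim0=10):
    return (output_dim0,) + tuple(x_shape[1:])   # (10, 4, 5)
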
class FractionalMaxPoolGrad(Primitive):
|
|
2428
|
+
"""Computes gradients for FractionalMaxPool operation."""
|
|
2429
|
+
|
|
2430
|
+
@prim_attr_register
|
|
2431
|
+
def __init__(self, overlapping=False):
|
|
2432
|
+
self.init_prim_io_names(inputs=["orig_input", "orig_output", "out_backprop",
|
|
2433
|
+
"row_pooling_sequence", "col_pooling_sequence"],
|
|
2434
|
+
outputs=["y"])
|
|
2435
|
+
validator.check_value_type("overlapping", overlapping, [bool], self.name)
|
|
2436
|
+
|
|
2437
|
+
|
|
2438
|
+
class FractionalMaxPool3DGradWithFixedKsize(Primitive):
|
|
2439
|
+
"""Computes gradients for FractionalMaxPool3DWithFixedKsize operation."""
|
|
2440
|
+
|
|
2441
|
+
@prim_attr_register
|
|
2442
|
+
def __init__(self, data_format="NCDHW"):
|
|
2443
|
+
self.init_prim_io_names(inputs=["origin_input", "out_backprop", "argmax"], outputs=["y"])
|
|
2444
|
+
self.data_format = validator.check_string(data_format, ['NCDHW', "NDHWC"], 'data_format', self.name)
|
|
2445
|
+
|
|
2446
|
+
|
|
2447
|
+
class MaxUnpool2DGrad(Primitive):
|
|
2448
|
+
r"""
|
|
2449
|
+
Gradients for MaxUnpool2D operation.
|
|
2450
|
+
"""
|
|
2451
|
+
|
|
2452
|
+
@prim_attr_register
|
|
2453
|
+
def __init__(self, ksize, strides=0, pads=0, output_shape=(), data_format="NCHW"):
|
|
2454
|
+
"""Initialize MaxUnpool2DGrad."""
|
|
2455
|
+
self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
|
|
2456
|
+
validator.check_value_type("ksize", ksize, [int, tuple], self.name)
|
|
2457
|
+
validator.check_value_type("strides", strides, [int, tuple], self.name)
|
|
2458
|
+
validator.check_value_type("pads", pads, [int, tuple], self.name)
|
|
2459
|
+
validator.check_value_type("output_shape", output_shape, [tuple], self.name)
|
|
2460
|
+
validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
|
|
2461
|
+
validator.check_int(len(ksize), 4, validator.EQ, "ksize rank", self.name)
|
|
2462
|
+
validator.check_int(len(strides), 4, validator.EQ, "strides rank", self.name)
|
|
2463
|
+
validator.check_int(len(pads), 4, validator.EQ, "pads rank", self.name)
|
|
2464
|
+
|
|
2465
|
+
|
|
2466
|
+
class MaxUnpool3DGrad(Primitive):
|
|
2467
|
+
r"""
|
|
2468
|
+
Gradients for MaxUnpool3D operation.
|
|
2469
|
+
"""
|
|
2470
|
+
|
|
2471
|
+
@prim_attr_register
|
|
2472
|
+
def __init__(self, ksize, strides=0, pads=0, output_shape=(), data_format="NCDHW"):
|
|
2473
|
+
"""Initialize MaxUnpool3DGrad."""
|
|
2474
|
+
self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
|
|
2475
|
+
validator.check_value_type("ksize", ksize, [int, tuple], self.name)
|
|
2476
|
+
validator.check_value_type("strides", strides, [int, tuple], self.name)
|
|
2477
|
+
validator.check_value_type("pads", pads, [int, tuple], self.name)
|
|
2478
|
+
validator.check_value_type("output_shape", output_shape, [tuple], self.name)
|
|
2479
|
+
validator.check_string(data_format, ['NCDHW', 'NDHWC'], 'data_format', self.name)
|
|
2480
|
+
validator.check_int(len(ksize), 5, validator.EQ, "ksize rank", self.name)
|
|
2481
|
+
validator.check_int(len(strides), 5, validator.EQ, "strides rank", self.name)
|
|
2482
|
+
validator.check_int(len(pads), 5, validator.EQ, "pads rank", self.name)
|
|
2483
|
+
|
|
2484
|
+
|
|
2485
|
+
class FractionalAvgPoolGrad(Primitive):
|
|
2486
|
+
"""Computes gradients for FractionalAvgPool operation."""
|
|
2487
|
+
|
|
2488
|
+
@prim_attr_register
|
|
2489
|
+
def __init__(self, overlapping=False):
|
|
2490
|
+
self.add_prim_attr("max_length", 1000000)
|
|
2491
|
+
self.init_prim_io_names(inputs=["orig_input_tensor_shape", "out_backprop", "row_pooling_sequence",
|
|
2492
|
+
"col_pooling_sequence"],
|
|
2493
|
+
outputs=["y"])
|
|
2494
|
+
validator.check_value_type("overlapping", overlapping, [bool], self.name)
|
|
2495
|
+
|
|
2496
|
+
|
|
2497
|
+
class PSROIPoolingGrad(Primitive):
|
|
2498
|
+
"""Computes gradients for PSROIPooling operation."""
|
|
2499
|
+
|
|
2500
|
+
@prim_attr_register
|
|
2501
|
+
def __init__(self, input_size, spatial_scale, group_size, output_dim):
|
|
2502
|
+
"""Initialize PSROIPoolingGrad."""
|
|
2503
|
+
self.init_prim_io_names(inputs=["x", "rois"], outputs=['y'])
|
|
2504
|
+
validator.check_value_type("input_size", input_size, [int, tuple], self.name)
|
|
2505
|
+
validator.check_positive_float(spatial_scale, "spatial_scale", self.name)
|
|
2506
|
+
validator.check_positive_int(group_size, "group_size", self.name)
|
|
2507
|
+
validator.check_positive_int(output_dim, "output_dim", self.name)
|
|
2508
|
+
|
|
2509
|
+
if isinstance(input_size, int):
|
|
2510
|
+
self.input_size = [input_size, input_size]
|
|
2511
|
+
else:
|
|
2512
|
+
self.input_size = list(input_size)
|
|
2513
|
+
|
|
2514
|
+
validator.check_positive_int_sequence(self.input_size, "input_size", self.name)
|
|
2515
|
+
self.spatial_scale = spatial_scale
|
|
2516
|
+
self.group_size = group_size
|
|
2517
|
+
self.output_dim = output_dim
|
|
2518
|
+
|
|
2519
|
+
self.add_prim_attr('input_size', self.input_size)
|
|
2520
|
+
self.add_prim_attr('spatial_scale', self.spatial_scale)
|
|
2521
|
+
self.add_prim_attr('group_size', self.group_size)
|
|
2522
|
+
self.add_prim_attr('output_dim', self.output_dim)
|
|
2523
|
+
|
|
2524
|
+
|
|
2525
|
+
class AdaptiveMaxPool3DGrad(Primitive):
|
|
2526
|
+
"""Computes gradients for AdaptiveMaxPool3D operation."""
|
|
2527
|
+
|
|
2528
|
+
@prim_attr_register
|
|
2529
|
+
def __init__(self):
|
|
2530
|
+
"""Initialize AdaptiveMaxPool3DGrad"""
|
|
2531
|
+
self.init_prim_io_names(inputs=['input_grad', 'x', 'argmax'], outputs=['output_grad'])
|
|
2532
|
+
|
|
2533
|
+
|
|
2534
|
+
class TraceGrad(Primitive):
|
|
2535
|
+
"""
|
|
2536
|
+
Computes grad for Trace operation.
|
|
2537
|
+
|
|
2538
|
+
Inputs:
|
|
2539
|
+
- **y_grad** (Tensor) - The gradient with respect to the output of the Trace function.
|
|
2540
|
+
Currently the grad data type supports float16, float32, int8, int16, int32, int64,
|
|
2541
|
+
uint8, uint16, uint32, uint64, float64.
|
|
2542
|
+
- **x_shape** (Tensor) - The shape of the input of the forward Trace function.
|
|
2543
|
+
Currently the shape data type supports int32 and int64.
|
|
2544
|
+
|
|
2545
|
+
Outputs:
|
|
2546
|
+
x_grad - Tensor, with the same data type as `y_grad` and shape given by `x_shape`.
|
|
2547
|
+
|
|
2548
|
+
Raises:
|
|
2549
|
+
TypeError: If `x_shape` is not a Tensor.
|
|
2550
|
+
TypeError: If the dtype of `x_shape` is neither int32 nor int64.
|
|
2551
|
+
ValueError: If `x_shape` is not a 1D Tensor.
|
|
2552
|
+
ValueError: If the number of elements of `x_shape` is not equal to 2.
|
|
2553
|
+
|
|
2554
|
+
Supported Platforms:
|
|
2555
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2556
|
+
"""
|
|
2557
|
+
|
|
2558
|
+
@prim_attr_register
|
|
2559
|
+
def __init__(self):
|
|
2560
|
+
self.init_prim_io_names(inputs=['y_grad', 'x_shape'], outputs=['x_grad'])
|
|
2561
|
+
|
|
2562
|
+
|
|
2563
|
+
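# A small NumPy sketch (illustrative only, independent of this operator; the helper
# name is an assumption) of what a trace backward produces: trace sums the main
# diagonal, so the upstream scalar gradient is scattered onto the main diagonal of a
# zero matrix of shape x_shape.
def _example_trace_grad_reference(y_grad=1.0, x_shape=(3, 4)):
    import numpy as np
    x_grad = np.zeros(x_shape, dtype=np.float32)
    np.fill_diagonal(x_grad, y_grad)
    return x_grad
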
class IgammaGradA(Primitive):
|
|
2564
|
+
r"""
|
|
2565
|
+
Computes the gradient of igamma(a, x) wrt a.
|
|
2566
|
+
|
|
2567
|
+
Inputs:
|
|
2568
|
+
- **a** (Tensor) - The input tensor. With float32 or float64 data type.
|
|
2569
|
+
- **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
|
|
2570
|
+
the same dtype with `a`.
|
|
2571
|
+
|
|
2572
|
+
Outputs:
|
|
2573
|
+
Tensor, has the same dtype as `a` and `x`.
|
|
2574
|
+
|
|
2575
|
+
Raises:
|
|
2576
|
+
TypeError: If `a` or `x` is not a Tensor.
|
|
2577
|
+
TypeError: If dtype of `x` and `a` is neither float32 nor float64.
|
|
2578
|
+
TypeError: If `x` has a different dtype than `a`.
|
|
2579
|
+
ValueError: If `a` could not be broadcast to a tensor with shape of `x`.
|
|
2580
|
+
|
|
2581
|
+
Supported Platforms:
|
|
2582
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2583
|
+
|
|
2584
|
+
Examples:
|
|
2585
|
+
>>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
|
|
2586
|
+
>>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
|
|
2587
|
+
>>> igammagrada = G.IgammaGradA()
|
|
2588
|
+
>>> output = igammagrada(a, x)
|
|
2589
|
+
>>> print (output)
|
|
2590
|
+
[-0.2940046 -0.20153049 -0.13028376 -0.08352186]
|
|
2591
|
+
"""
|
|
2592
|
+
|
|
2593
|
+
@prim_attr_register
|
|
2594
|
+
def __init__(self):
|
|
2595
|
+
"""Initialize IgammaGradA"""
|
|
2596
|
+
self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
|
|
2597
|
+
|
|
2598
|
+
|
|
2599
|
+
class DeformableOffsetsGrad(Primitive):
|
|
2600
|
+
r"""
|
|
2601
|
+
Computes gradients of DeformableOffsets operation.
|
|
2602
|
+
Args:
|
|
2603
|
+
strides (tuple[int, int, int, int]): A tuple of 4 integers. The stride of sliding windows for height
|
|
2604
|
+
and width for H/W dimension.
|
|
2605
|
+
pads (tuple[int, int, int, int]): A tuple of 4 integers. Padding added to the H/W dimensions of the input. The number
|
|
2606
|
+
of pixels to add to each (top, bottom, left, right) side of the input.
|
|
2607
|
+
kernel_size (tuple[int, int]): Kernel size, a tuple of 2 integers.
|
|
2608
|
+
dilations (tuple[int, int, int, int]): A tuple of 4 integers. The dilation factor for each dimension of
|
|
2609
|
+
input. Default: (1, 1, 1, 1).
|
|
2610
|
+
data_format (str): An optional string from: "NCHW", "NHWC". Specify the data format of the input x. Default:
|
|
2611
|
+
"NCHW".
|
|
2612
|
+
deformable_groups (int): Specify the C-axis grouping number of input x. Default: 1.
|
|
2613
|
+
modulated (bool): Specify version of DeformableOffsetsGrad, true means v2, false means v1. Default: ``True``.
|
|
2614
|
+
|
|
2615
|
+
Inputs:
|
|
2616
|
+
- **grad** (Tensor) - The input grad tensor. With float16 or float32 data type.
|
|
2617
|
+
- **x** (Tensor) - The input `x` of DeformableOffsets with data type of float16 or float32.
|
|
2618
|
+
- **offsets** (Tensor) - The input 'offsets' of DeformableOffsets with data type of float16 or float32.
|
|
2619
|
+
|
|
2620
|
+
Outputs:
|
|
2621
|
+
- **grad_x** (Tensor) - The output grad of input `x`. With same dtype and shape of input `x`.
|
|
2622
|
+
- ""grad_offsets** (Tensor) - The output grad of input `offsets`. With same dtype and shape of input `offsets`.
|
|
2623
|
+
|
|
2624
|
+
Supported Platforms:
|
|
2625
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2626
|
+
"""
|
|
2627
|
+
|
|
2628
|
+
@prim_attr_register
|
|
2629
|
+
def __init__(self,
|
|
2630
|
+
strides,
|
|
2631
|
+
pads,
|
|
2632
|
+
kernel_size,
|
|
2633
|
+
dilations=(1, 1, 1, 1),
|
|
2634
|
+
data_format="NCHW",
|
|
2635
|
+
deformable_groups=1,
|
|
2636
|
+
modulated=True):
|
|
2637
|
+
"""Initialize DeformableOffsetsGrad"""
|
|
2638
|
+
self.init_prim_io_names(inputs=['out_backprop', 'input', 'offsets'], outputs=['out_grad'])
|
|
2639
|
+
|
|
2640
|
+
self.strides = _check_positive_int_or_tuple('strides', strides, self.name, allow_four=True, ret_four=True)
|
|
2641
|
+
self.add_prim_attr('strides', self.strides)
|
|
2642
|
+
|
|
2643
|
+
self.pads = pads
|
|
2644
|
+
self.add_prim_attr('pads', self.pads)
|
|
2645
|
+
|
|
2646
|
+
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name, allow_four=True,
|
|
2647
|
+
ret_four=False)
|
|
2648
|
+
self.add_prim_attr('ksize', self.kernel_size)
|
|
2649
|
+
|
|
2650
|
+
self.dilations = _check_positive_int_or_tuple('dilations', dilations, self.name, allow_four=True, ret_four=True)
|
|
2651
|
+
self.add_prim_attr('dilations', dilations)
|
|
2652
|
+
|
|
2653
|
+
self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
2654
|
+
self.add_prim_attr('data_format', self.data_format)
|
|
2655
|
+
|
|
2656
|
+
self.deformable_groups = validator.check_positive_int(deformable_groups, 'deformable_groups', self.name)
|
|
2657
|
+
self.add_prim_attr('deformable_groups', self.deformable_groups)
|
|
2658
|
+
|
|
2659
|
+
self.modulated = validator.check_bool(modulated, 'modulated', self.name)
|
|
2660
|
+
self.add_prim_attr('modulated', self.modulated)
|
|
2661
|
+
|
|
2662
|
+
|
|
2663
|
+
class MedianGrad(Primitive):
|
|
2664
|
+
"""
|
|
2665
|
+
Computes gradient for Median operation.
|
|
2666
|
+
|
|
2667
|
+
.. warning::
|
|
2668
|
+
When attr `global_median` is True, the value of Median's second output Tensor `indices` is meaningless.
|
|
2669
|
+
|
|
2670
|
+
Args:
|
|
2671
|
+
global_median (bool): Whether the output tensor is the global median of all input tensor elements
|
|
2672
|
+
or not in Median operation.
|
|
2673
|
+
axis (int): The dimension need to reduce in Median operation.
|
|
2674
|
+
keep_dims (bool): Whether the output tensor need to retain `axis` dimension or not in Median operation.
|
|
2675
|
+
|
|
2676
|
+
Inputs:
|
|
2677
|
+
- **y_grad** (Tensor) - The gradients of loss to output of Median function.
|
|
2678
|
+
- **x** (Tensor) - The first input is a tensor whose data type is number.
|
|
2679
|
+
The dtype is one of the following: int16, int32, int64, float32, double.
|
|
2680
|
+
- **y** (Tensor) - The first output of the Median function, whose data type is the same as `x`.
|
|
2681
|
+
- **indices** (Tensor) - The second output of the Median function, whose data type is int64.
|
|
2682
|
+
|
|
2683
|
+
Outputs:
|
|
2684
|
+
x_grad - Tensor, has the same shape as the `x`, dtype is double only when dtype of `x` is double.
|
|
2685
|
+
Otherwise, dtype of `x_grad` is float32.
|
|
2686
|
+
|
|
2687
|
+
Raises:
|
|
2688
|
+
TypeError: If dtype of `y_grad` is not the same as `x`.
|
|
2689
|
+
ValueError: If shape of `y_grad` is not the same as `y`.
|
|
2690
|
+
|
|
2691
|
+
Supported Platforms:
|
|
2692
|
+
``Ascend`` ``CPU``
|
|
2693
|
+
"""
|
|
2694
|
+
|
|
2695
|
+
@prim_attr_register
|
|
2696
|
+
def __init__(self, global_median=False, axis=0, keep_dims=False):
|
|
2697
|
+
validator.check_value_type("global_median", global_median, [bool], self.name)
|
|
2698
|
+
self.global_median = global_median
|
|
2699
|
+
if global_median is False:
|
|
2700
|
+
validator.check_value_type("axis", axis, [int], self.name)
|
|
2701
|
+
validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
|
|
2702
|
+
self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad'])
|
|
2703
|
+
|
|
2704
|
+
|
|
2705
|
+
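# A tiny sketch (illustrative only; the helper name is an assumption) of the documented
# dtype rule for MedianGrad above: the output x_grad is double only when `x` is double,
# otherwise it is float32.
def _example_median_grad_dtype(x_dtype):
    import numpy as np
    return np.float64 if x_dtype == np.float64 else np.float32
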
class SparseSegmentSumGrad(Primitive):
|
|
2706
|
+
"""
|
|
2707
|
+
Computes gradients for the SparseSegmentSum operation.
|
|
2708
|
+
|
|
2709
|
+
Inputs:
|
|
2710
|
+
- **grad** (Tensor) - A tensor.
|
|
2711
|
+
- **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
|
|
2712
|
+
Has same rank as segment_ids. The shape should be :math:`(N,)`.
|
|
2713
|
+
- **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
|
|
2714
|
+
Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
|
|
2715
|
+
- **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSum op.
|
|
2716
|
+
|
|
2717
|
+
Outputs:
|
|
2718
|
+
A Tensor. Has the same type as `grad` .
|
|
2719
|
+
Has same shape as `grad`, except for dimension 0 which is the value of `output_dim0`.
|
|
2720
|
+
|
|
2721
|
+
Raises:
|
|
2722
|
+
TypeError: If `grad` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
|
|
2723
|
+
TypeError: If the dtype of `grad` is not any of the following data types: {float16, float32, float64}.
|
|
2724
|
+
TypeError: If the dtype of `indices` and `segment_ids` and `output_dim0` is not int32 or int64.
|
|
2725
|
+
ValueError: If dimension size of `grad` less than 1.
|
|
2726
|
+
ValueError: If rank of `indices` or `segment_ids` is not 1.
|
|
2727
|
+
ValueError: If dimension size of `output_dim0` is not 0.
|
|
2728
|
+
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
|
|
2729
|
+
ValueError: If `segment_ids` is not sorted.
|
|
2730
|
+
ValueError: If the last number of `segment_ids` is out of range of grad's first shape.
|
|
2731
|
+
ValueError: If `indices` is bigger than or equal to `output_dim0`.
|
|
2732
|
+
|
|
2733
|
+
Supported Platforms:
|
|
2734
|
+
``GPU``
|
|
2735
|
+
"""
|
|
2736
|
+
__mindspore_signature__ = (
|
|
2737
|
+
sig.make_sig('grad', dtype=sig.sig_dtype.T1),
|
|
2738
|
+
sig.make_sig('indices', dtype=sig.sig_dtype.T),
|
|
2739
|
+
sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
|
|
2740
|
+
sig.make_sig('output_dim0', dtype=sig.sig_dtype.T)
|
|
2741
|
+
)
|
|
2742
|
+
|
|
2743
|
+
@prim_attr_register
|
|
2744
|
+
def __init__(self):
|
|
2745
|
+
"""Initialize SparseSegmentSumGrad"""
|
|
2746
|
+
self.init_prim_io_names(inputs=['grad', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
|
|
2747
|
+
|
|
2748
|
+
|
|
2749
|
+
class SparseSegmentSqrtNGrad(Primitive):
|
|
2750
|
+
"""
|
|
2751
|
+
Computes gradients for the SparseSegmentSqrtN operation.
|
|
2752
|
+
|
|
2753
|
+
Inputs:
|
|
2754
|
+
- **x** (Tensor) - A tensor. Its rank must be greater than or equal to one.
|
|
2755
|
+
- **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following
|
|
2756
|
+
types: int32, int64. Has same rank as segment_ids. The shape should be :math:`(N,)`.
|
|
2757
|
+
- **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one
|
|
2758
|
+
of the following types: int32, int64. Values should be sorted and can be repeated. The shape should
|
|
2759
|
+
be :math:`(N,)`.
|
|
2760
|
+
- **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSqrtN op.
|
|
2761
|
+
|
|
2762
|
+
Outputs:
|
|
2763
|
+
A Tensor. Has the same type as `x` .
|
|
2764
|
+
Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.
|
|
2765
|
+
|
|
2766
|
+
Raises:
|
|
2767
|
+
TypeError: If `x` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
|
|
2768
|
+
TypeError: If the dtype of `x` is not any of the following data types: {float16, float32, float64}.
|
|
2769
|
+
TypeError: If the dtype of `indices` is not int32.
|
|
2770
|
+
TypeError: If the dtype of `segment_ids` is not int32.
|
|
2771
|
+
TypeError: If the dtype of `output_dim0` is not int32.
|
|
2772
|
+
ValueError: If dimension size of `x` is less than 1.
|
|
2773
|
+
ValueError: If rank of `indices` or `segment_ids` is not 1.
|
|
2774
|
+
ValueError: If dimension size of `output_dim0` is not 0.
|
|
2775
|
+
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
|
|
2776
|
+
ValueError: If `segment_ids` is not sorted.
|
|
2777
|
+
ValueError: If the last number of `segment_ids` is out of range of x's first shape.
|
|
2778
|
+
ValueError: If `indices` is bigger than or equal to `output_dim0`.
|
|
2779
|
+
|
|
2780
|
+
Supported Platforms:
|
|
2781
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2782
|
+
"""
|
|
2783
|
+
|
|
2784
|
+
@prim_attr_register
|
|
2785
|
+
def __init__(self):
|
|
2786
|
+
"""Initialize SparseSegmentSqrtNGrad"""
|
|
2787
|
+
self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
|
|
2788
|
+
|
|
2789
|
+
|
|
2790
|
+
class SparseSliceGrad(Primitive):
|
|
2791
|
+
r"""
|
|
2792
|
+
Computes gradients for SparseSlice operation.
|
|
2793
|
+
|
|
2794
|
+
Inputs:
|
|
2795
|
+
- **backprop_val_grad** (Tensor) - A 1D Tensor.
|
|
2796
|
+
The shape should be :math:`(N,)`.
|
|
2797
|
+
- **indices** (Tensor) - A 2D Tensor (N x R matrix) of type int64. The indices of the SparseTensor.
|
|
2798
|
+
Support int64, each element value should be a non-negative int number. This tensor should be sorted.
|
|
2799
|
+
The shape is :math:`(N, R)`.
|
|
2800
|
+
- **start** (Tensor) - A 1D Tensor of type int64, represents the start of the indices.
|
|
2801
|
+
The shape should be :math:`(R,)`.
|
|
2802
|
+
- **new_indices** (Tensor) - A 2D Tensor (N x C matrix) of type int64. The indices of the SparseTensor.
|
|
2803
|
+
Support int64, each element value should be a non-negative int number. This tensor should be sorted.
|
|
2804
|
+
The shape is :math:`(N, C)`.
|
|
2805
|
+
|
|
2806
|
+
Outputs:
|
|
2807
|
+
- **y_grad_val** (Tensor) - Has the same type as `backprop_val_grad`.
|
|
2808
|
+
Has one element for each row of `indices`.
|
|
2809
|
+
|
|
2810
|
+
Raises:
|
|
2811
|
+
TypeError: If the dtype of `indices`, `start`, `new_indices` are not int64.
|
|
2812
|
+
ValueError: If `indices`, `new_indices` are not 2-D tensor.
|
|
2813
|
+
ValueError: If `backprop_val_grad`, `start` is not a 1-D tensor.
|
|
2814
|
+
ValueError: If the number of `backprop_val_grad` is not corresponding to the number of `new_indices`.
|
|
2815
|
+
ValueError: If the shape of `indices[1]` is not corresponding to `start[1]`.
|
|
2816
|
+
ValueError: If the shape of `indices[1]` is not corresponding to `new_indices[1]`.
|
|
2817
|
+
RuntimeError: If the `backprop_val_grad` is not all backpropagated, because `indices` or `new_indices`
|
|
2818
|
+
is not sorted.
|
|
2819
|
+
|
|
2820
|
+
Supported Platforms:
|
|
2821
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2822
|
+
Examples:
|
|
2823
|
+
>>> backprop_val_grad = Tensor(np.array([1, 2, 3, 4]).astype(np.int64))
|
|
2824
|
+
>>> indices = Tensor(np.array([[0, 0], [0, 2], [1, 2], [1, 3], [2, 3], [2, 4]]).astype(np.int64))
|
|
2825
|
+
>>> start = Tensor(np.array([0, 0]).astype(np.int64))
|
|
2826
|
+
>>> new_indices = Tensor(np.array([[0, 2], [1, 2], [1, 3], [2, 4]]).astype(np.int64))
|
|
2827
|
+
>>> grad = SparseSliceGrad()
|
|
2828
|
+
>>> output = grad(backprop_val_grad, indices, start, new_indices)
|
|
2829
|
+
>>> print(output)
|
|
2830
|
+
[0 1 2 3 0 4]
|
|
2831
|
+
"""
|
|
2832
|
+
|
|
2833
|
+
@prim_attr_register
|
|
2834
|
+
def __init__(self):
|
|
2835
|
+
"""Initialize SparseSliceGrad."""
|
|
2836
|
+
self.init_prim_io_names(inputs=['backprop_val_grad', 'indices', 'start', 'new_indices'], outputs=['y_grad'])
|
|
2837
|
+
|
|
2838
|
+
|
|
2839
|
+
class FractionalMaxPoolGradWithFixedKsize(Primitive):
|
|
2840
|
+
"""
|
|
2841
|
+
Computes the gradients of FractionalMaxPoolWithFixedKsize.
|
|
2842
|
+
|
|
2843
|
+
Args:
|
|
2844
|
+
data_format (str): The optional value for data format. Only 'NCHW' is supported. Default: "NCHW".
|
|
2845
|
+
|
|
2846
|
+
Inputs:
|
|
2847
|
+
- **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be int32 or int64.
|
|
2848
|
+
- **out_backprop** (Tensor) - The gradients with respect to the output of FractionalMaxPoolWithFixedKsize
|
|
2849
|
+
function. Tensor with data format "NCHW", whose data type is float16, float32, float64, int32 or int64.
|
|
2850
|
+
- **argmax** (Tensor) - The second output of FractionalMaxPoolWithFixedKsize function, whose data
|
|
2851
|
+
type is int64.
|
|
2852
|
+
|
|
2853
|
+
Outputs:
|
|
2854
|
+
- **y** (Tensor) - Tensor, with the same shape as `origin_input`, and the same data type as
|
|
2855
|
+
the input `out_backprop`.
|
|
2856
|
+
|
|
2857
|
+
Raises:
|
|
2858
|
+
TypeError: If data type of `out_backprop` is not one of the following: float16, float32, float64, int32, int64.
|
|
2859
|
+
TypeError: If data type of `argmax` is not int64.
|
|
2860
|
+
ValueError: If the shape of `out_backprop` and `argmax` is not equal.
|
|
2861
|
+
ValueError: If the first dimension size of `origin_input` and `out_backprop` is not equal.
|
|
2862
|
+
ValueError: If the second dimension size of `origin_input` and `out_backprop` is not equal.
|
|
2863
|
+
|
|
2864
|
+
Supported Platforms:
|
|
2865
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2866
|
+
"""
|
|
2867
|
+
|
|
2868
|
+
@prim_attr_register
|
|
2869
|
+
def __init__(self, data_format="NCHW"):
|
|
2870
|
+
self.data_format = validator.check_string(data_format, ['NCHW'], 'data_format', self.name)
|
|
2871
|
+
self.add_prim_attr("data_format", self.data_format)
|
|
2872
|
+
self.init_prim_io_names(inputs=['origin_input', 'out_backprop', 'argmax'], outputs=['y'])
|
|
2873
|
+
|
|
2874
|
+
|
|
2875
|
+
class AffineGridGrad(Primitive):
|
|
2876
|
+
r"""
|
|
2877
|
+
Computes gradients for AffineGrid operation.
|
|
2878
|
+
|
|
2879
|
+
Args:
|
|
2880
|
+
align_corners (bool): if True, consider -1 and 1 to refer to the centers
|
|
2881
|
+
of the corner pixels rather than the image corners. Default: ``False``.
|
|
2882
|
+
|
|
2883
|
+
Inputs:
|
|
2884
|
+
- **y_grad** (Tensor) - Data type must be float16 or float32.
|
|
2885
|
+
- **x_size** (tuple) - Data type must be int32 or int64.
|
|
2886
|
+
|
|
2887
|
+
Outputs:
|
|
2888
|
+
Tensor, with data type same as `y_grad`.
|
|
2889
|
+
|
|
2890
|
+
Supported Platforms:
|
|
2891
|
+
``CPU``
|
|
2892
|
+
|
|
2893
|
+
Examples:
|
|
2894
|
+
>>> import mindspore.ops.operations._grad_ops as _grad_ops
|
|
2895
|
+
>>> affinegridgrad = _grad_ops.AffineGridGrad()
|
|
2896
|
+
>>> y_grad = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
|
|
2897
|
+
>>> x_size = (1, 2, 2, 2)
|
|
2898
|
+
>>> x_grad = affinegridgrad(y_grad, x_size)
|
|
2899
|
+
>>> print(x_grad)
|
|
2900
|
+
[[[0. 0. 4.]
|
|
2901
|
+
[0. 0. 4.]]]
|
|
2902
|
+
"""
|
|
2903
|
+
|
|
2904
|
+
@prim_attr_register
|
|
2905
|
+
def __init__(self, align_corners=False):
|
|
2906
|
+
"""Initialize AffineGridGrad."""
|
|
2907
|
+
validator.check_value_type("align_corners", align_corners, [bool], self.name)
|
|
2908
|
+
self.init_prim_io_names(inputs=['y_grad', 'x_size'], outputs=['x_grad'])
|
|
2909
|
+
|
|
2910
|
+
|
|
2911
|
+
|
|
2912
|
+
class GluGrad(Primitive):
|
|
2913
|
+
"""
|
|
2914
|
+
Computes grad for Glu operation.
|
|
2915
|
+
"""
|
|
2916
|
+
|
|
2917
|
+
@prim_attr_register
|
|
2918
|
+
def __init__(self, axis):
|
|
2919
|
+
self.add_prim_attr("cust_aicpu", self.name)
|
|
2920
|
+
self.init_prim_io_names(inputs=["grads", "x"], outputs=["y"])
|
|
2921
|
+
validator.check_value_type("axis", axis, [int], self.name)
|
|
2922
|
+
|
|
2923
|
+
|
|
2924
|
+
class MapTensorGetGrad(Primitive):
|
|
2925
|
+
"""
|
|
2926
|
+
Computes gradients for MapTensorGet operation.
|
|
2927
|
+
|
|
2928
|
+
Inputs:
|
|
2929
|
+
- **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
|
|
2930
|
+
- **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
|
|
2931
|
+
- **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
|
|
2932
|
+
- **grad** (Tensor) - The grad value according the forward operator MapTensorGet.
|
|
2933
|
+
|
|
2934
|
+
Outputs:
|
|
2935
|
+
- **output** (MapTensor) - MapTensor with grad values.
|
|
2936
|
+
"""
|
|
2937
|
+
@prim_attr_register
|
|
2938
|
+
def __init__(self):
|
|
2939
|
+
"""Initialize MapTensorGetGrad"""
|
|
2940
|
+
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
|
|
2941
|
+
self.add_prim_attr('side_effect_mem', True)
|
|
2942
|
+
|
|
2943
|
+
|
|
2944
|
+
class ResizeV2Grad(Primitive):
|
|
2945
|
+
r"""
|
|
2946
|
+
Calculates the gradient of ResizeV2 operation.
|
|
2947
|
+
|
|
2948
|
+
Supported Platforms:
|
|
2949
|
+
``CPU``
|
|
2950
|
+
"""
|
|
2951
|
+
|
|
2952
|
+
@prim_attr_register
|
|
2953
|
+
def __init__(self, coordinate_transformation_mode="half_pixel", mode="nearest"):
|
|
2954
|
+
"""Initialize ResizeV2Grad."""
|
|
2955
|
+
self.init_prim_io_names(inputs=["grads", "roi", "scales", "original_size"], outputs=["y"])
|
|
2956
|
+
self.add_prim_attr("nearest_mode", "floor")
|
|
2957
|
+
self.add_prim_attr("cubic_coeff_a", -0.75)
|
|
2958
|
+
validator.check_value_type(
|
|
2959
|
+
"coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
|
|
2960
|
+
validator.check_string(coordinate_transformation_mode,
|
|
2961
|
+
["align_corners", "half_pixel"], "coordinate_transformation_mode", self.name)
|
|
2962
|
+
validator.check_value_type("mode", mode, [str], self.name)
|
|
2963
|
+
validator.check_string(mode, ["nearest", "linear", "cubic"], "mode", self.name)
|
|
2964
|
+
|
|
2965
|
+
|
|
2966
|
+
class WKVGrad(Primitive):
|
|
2967
|
+
r"""
|
|
2968
|
+
Calculates the gradient of WKV operation.
|
|
2969
|
+
|
|
2970
|
+
Supported Platforms:
|
|
2971
|
+
``Ascend``
|
|
2972
|
+
"""
|
|
2973
|
+
|
|
2974
|
+
@prim_attr_register
|
|
2975
|
+
def __init__(self):
|
|
2976
|
+
"""Initialize WKVGrad."""
|
|
2977
|
+
self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
|
|
2978
|
+
outputs=["gw", "gu", "gk", "gv"])
|