mindspore-2.4.0-cp311-cp311-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. The information it contains is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2544 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Inner operators."""
from types import FunctionType, MethodType
from collections.abc import Iterable
import os
import weakref
import numpy as np

from mindspore.common import Tensor
from mindspore.common._stub_tensor import StubTensor
from mindspore.ops import composite as C
from mindspore.ops.operations.array_ops import Cast
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
from mindspore.ops import signature as sig
from mindspore.ops.operations.math_ops import _infer_shape_reduce
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
    _run_op, _check_contains_variable
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import typing, HookType
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
from mindspore.common.api import _pynative_executor
from mindspore.common._register_for_adapter import ms_adapter_registry
from mindspore import ops
from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
    PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4

# Bit operation
bit_and = bit_and()
bit_or = bit_or()
bit_xor = Primitive("bit_xor")
bit_left_shift = Primitive("bit_left_shift")
bit_right_shift = Primitive("bit_right_shift")
# String operation
string_lt = Primitive("string_lt")
string_gt = Primitive("string_gt")
string_le = Primitive("string_le")
string_ge = Primitive("string_ge")
string_not = Primitive("string_not")
string_in = Primitive("string_in")
string_mul = Primitive("string_mul")
string_getitem = Primitive("string_getitem")


class Generator(Primitive):
    r"""
    Manage the state of random number generation.

    Inputs:
        - **cmd** (int) : operation to be executed.
        - **inputs** (tuple[tensor]) : inputs for the operation.

    Outputs:
        - **seed** (Tensor): Seed for the random number generation algorithm.
        - **offset** (Tensor): Offset of the random number sequence.
        - **state** (Tensor): State tensor, can be used to restore current state.
    """

    @prim_attr_register
    def __init__(self):
        self.add_prim_attr("side_effect_mem", True)

    def __call__(self, cmd, inputs):
        if cmd == 0:  # step cmd
            return inputs[0], inputs[1]
        return super().__call__(cmd, inputs)


class Quant(PrimitiveWithInfer):
    r"""
    Returns the quantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = round(scale * x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = round(scale * x * scale + offset)

    Note:
        This operation only support Atlas 200/300/500 inference product.

    Args:
        scale (float) : Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
            Default: "Round".

    Inputs:
        - **input_x** (Tensor) : Input tensor. Its data type must be mindspore.float16 or mindspore.float32.

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.int8.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
        >>> y = quant(input_x)
    """

    @prim_attr_register
    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
        self.scale = validator.check_value_type("scale", scale, [float], self.name)
        self.offset = validator.check_value_type("offset", offset, [float], self.name)
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                 "round_mode", self.name)
        self.add_prim_attr("dst_type", mstype.int8)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
        return self.get_attr_dict()['dst_type']
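
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the packaged file): a plain numpy
# rendering of the Quant formula documented above, y = round(scale * x + offset),
# cast to int8. The int8 clipping below is an assumption for illustration only;
# it does not claim to reproduce the exact rounding/saturation behaviour of the
# Atlas hardware kernel.
def _quant_formula_sketch(x, scale, offset, sqrt_mode=False):
    import numpy as np  # local import keeps the sketch self-contained
    s = scale * scale if sqrt_mode else scale  # sqrt_mode squares the scale
    y = np.round(s * np.asarray(x, dtype=np.float32) + offset)
    return np.clip(y, -128, 127).astype(np.int8)
# Example: _quant_formula_sketch([1.0, -0.5], scale=80.0, offset=0.0) -> [80, -40]
# ---------------------------------------------------------------------------
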
class Lamb(PrimitiveWithInfer):
    r"""
    LAMB optimizer algorithm.

    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    Inputs:
        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
          any number of additional dimensions. The data type can be float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula,
          the shape and data type value should be the same as `var`.
        - **v** (Tensor) - the 2nd moment vector in the updating formula,
          the shape and data type value should be the same as `var`. Mean square gradients with the same type as `var`.
        - **lr** (float) - :math:`l` in the updating formula. The paper suggested value is :math:`10^{-8}`,
          the data type value should be the same as `var`.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations,
          the data type value should be the same as `var`. The paper suggested value is :math:`0.9`
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations,
          the data type value should be the same as `var`. The paper suggested value is :math:`0.999`
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **decay** (float) - The weight decay value, must be a scalar tensor with float data type.
          Default: 0.0.
        - **global_step** (Tensor) - Tensor to record current global step.
        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

    Outputs:
        Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.

    Supported Platforms:
        ``Ascend````GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Lamb."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
        validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, validator.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
                "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype


class Dequant(PrimitiveWithInfer):
    r"""
    Returns the dequantized value of input_x.
    This operation will do ReLU to the dequantized value if `relu_flag` is True.

    If `sqrt_mode` is False:

    .. math::
        y = x * deq\_scale

    If `sqrt_mode` is True:

    .. math::
        y = x * deq\_scale * deq\_scale

    Note:
        This operation only support Atlas 200/300/500 inference product.

    Args:
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
        relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
        - **deq_scale** (Tensor) : Specifies the scaling ratio.
          Data type must be mindspore.float16 or mindspore.uint64

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.float16.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> dequant = ops.Dequant(False, False)
        >>> y = dequant(input_x)
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False, dtype=mstype.float16):
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
        self.dtype = dtype

    def infer_shape(self, x_shape, deq_scale_shape):
        return x_shape

    def infer_dtype(self, x_type, deq_scale_type):
        validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("x", x_type, [mstype.int32], self.name)
        validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
        return mstype.float16
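
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the packaged file): the Dequant
# formula documented above, y = x * deq_scale (scale applied twice when
# sqrt_mode is True), with an optional ReLU, written in plain numpy for clarity.
def _dequant_formula_sketch(x, deq_scale, sqrt_mode=False, relu_flag=False):
    import numpy as np
    s = deq_scale * deq_scale if sqrt_mode else deq_scale
    y = np.asarray(x, dtype=np.float32) * s
    return np.maximum(y, 0.0) if relu_flag else y  # relu_flag applies ReLU to the result
# Example: _dequant_formula_sketch([100, -150], 0.01, relu_flag=True) -> [1.0, 0.0]
# ---------------------------------------------------------------------------
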
class AntiQuant(Primitive):
    r"""
    Returns the antiquantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = scale * (x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = scale * scale * (x + offset)

    Note:
        This operation only support Atlas 200/300/500 inference product.

    Args:
        scale (float) : Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) : Input tensor. Must be mindspore.int8.

    Outputs:
        - Tensor: The antiquantized output tensor of type mindspore.float32.

    Examples:
        >>> from mindspore.ops.operations._inner_ops import AntiQuant
        >>> input_x = Tensor([50.0, 20.0], mstype.int8)
        >>> antiquant = AntiQuant(2.0, 1.0, False)
        >>> y = antiquant(input_x)
        >>> print(y)
        [102. 42.]
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, dtype=mstype.float16):
        super().__init__("AntiQuant")
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.dtype = dtype

        self.init_prim_io_names(inputs=['x', 'scale', 'offset'],
                                outputs=['y'])


class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with a given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor which to be element-wise multi by `assist`. It can be one of the following data
          types: float32, float16, int32, int8, and uint8.
        - **assist** (Tensor) - A eye tensor of the same type as `x`. It's rank must be greater than or equal to 2 and
          it's last dimension must be equal to the second to last dimension.

    Outputs:
        Tensor, has the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = ops.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        >>> print(result)
        [[[-12. 11.]
        [-10. 9.]]
        [[ -8. 7.]
        [ -6. 5.]]
        [[ -4. 3.]
        [ -2. 1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(assist_shape), 2, validator.GE, "assist rank", self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), validator.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], validator.EQ, self.name)

        r_end_dim = -len(x_shape)
        r_idx = -1
        while r_idx >= r_end_dim:
            if x_shape[r_idx] != 1:
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], validator.EQ, self.name)
            r_idx = r_idx - 1

        return assist_shape


class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = ops.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        >>> print(result)
        [[12., -9.], [8., -5.], [4., -1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        if assist_shape[-2] < assist_shape[-1]:
            out_shape = assist_shape[:-1]
        else:
            out_shape = assist_shape[:-2] + assist_shape[-1:]
        return out_shape
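
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the packaged file): the shape rule
# used by MatrixDiagPart.infer_shape above, x.shape[:-2] + (min(x.shape[-2:]),),
# matches what numpy's batched diagonal extraction produces.
def _matrix_diag_part_sketch(x):
    import numpy as np
    x = np.asarray(x)
    return np.diagonal(x, axis1=-2, axis2=-1)  # shape: x.shape[:-2] + (min(x.shape[-2:]),)
# Example: _matrix_diag_part_sketch(np.arange(12).reshape(3, 2, 2)).shape -> (3, 2)
# ---------------------------------------------------------------------------
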
class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
        - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape same as `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal)
        >>> print(result)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]

    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        if x_shape[-2] < x_shape[-1]:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the last dimension",
                            x_shape[:-1], validator.EQ, self.name)
        else:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
                            x_shape[:-2] + x_shape[-1:], validator.EQ, self.name)

        return assist_shape


class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the dot product result of input0 and input1.

    `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it.

    Args:
        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default:(), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length as 1.
            - If false, don't keep these dimensions. Default:False.

    Inputs:
        - **input_0** (Tensor) - The input Tensor.
        - **input_1** (Tensor) - The input Tensor.
        - **input_2** (Tensor) - The input Tensor.

    Outputs:
        - **output_0** (Tensor) - The same shape as `input0`.
        - **output_1** (Tensor)

          - If axis is (), and keep_dims is false, the output is a 0-D array representing
            the sum of all elements in the input array.
          - If axis is int, set as 2, and keep_dims is false,
            the shape of output is :math:`(x_1,x_3,...,x_R)`.
          - If axis is tuple(int), set as (2,3), and keep_dims is false,
            the shape of output is :math:`(x_1,x_4,...x_R)`.

    Examples:
        >>> confusion_mul_grad = ops.ConfusionMulGrad()
        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
        output_0:
        [[ 3. 1. 0.]
        [-6. 2. -2.]]
        output_1:
        -3.0
    """

    @prim_attr_register
    def __init__(self, axis=(), keep_dims=False):
        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)

    def infer_shape(self, input0_shape, input1_shape, input2_shape):
        outshape0 = input0_shape
        outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
        return outshape0, outshape1

    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
        validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor_type, self.name)
        return input0_dtype, input1_dtype


class ConvertToDynamic(PrimitiveWithCheck):
    """
    This op is used for dynamic rank testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically ranked.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic rank support.

    Args:
        is_dynamic_rank (bool): If true, convert to dynamic rank.
            If false, convert to dynamic shape. Default: ``False``.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicNet, self).__init__()
        >>>         self.convert_to_dynamic = inner.ConvertToDynamic()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_input = self.convert_to_dynamic(input)
        >>>         reshaped_input = self.reshape(dynamic_input, new_shape)
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
        >>> input = Tensor(np.array([0, 1, 2, 3])
        >>> new_shape = (2, 2)
        >>> net = TestDynamicNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]
    """

    @prim_attr_register
    def __init__(self, is_dynamic_rank=False):
        validator.check_value_type('is_dynamic_rank', is_dynamic_rank, [bool], self.name)
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)


class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically shaped.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicShapeReshapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         reshaped_input = self.reshape(input, new_shape)
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3])
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)


class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. The only purpose of this operator is
    that it will throw a value error if the input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class AssertDynamicShapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDynamicShapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         self.error_on_dynamic_shape_input(dynamic_shape_input)
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0])
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        shape = list(input_shape)

        for dim in shape:
            if dim == -1:
                raise ValueError("Input is dynamically shaped.")

        return input_shape

    def infer_type(self, input_dtype):
        """Infer the dtype of input for ErrorOnDynamicShapeInput."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
        return input_dtype

    def infer_value(self, input_tensor):
        return input_tensor


class SequenceMask(PrimitiveWithCheck):
    """
    Returns a mask tensor representing the first N positions of each cell.

    If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64.

        - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
          type as elements in `lengths`.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> import numpy as np
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> sequence_mask = ops.SequenceMask()
        >>> output = sequence_mask(x, 3)
        >>> print(output)
        [[[True False False]
        [True True True]]
        [[True True False]
        [False False False]]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

    def check_shape(self, lengths_shape, maxlen_shape):
        validator.check("lengths_shape", len(lengths_shape), "", 0, validator.GT, self.name)
        validator.check("maxlen_shape", len(maxlen_shape), "", 0, validator.EQ, self.name)

    def check_dtype(self, lengths_dtype, maxlen_dtype):
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
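
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the packaged file): the mask rule
# documented above, mask[..., j] = (j < lengths[...]), expressed with numpy
# broadcasting; it mirrors the output shape rule lengths.shape + (maxlen,).
def _sequence_mask_sketch(lengths, maxlen):
    import numpy as np
    lengths = np.asarray(lengths)
    return np.arange(maxlen) < lengths[..., None]  # broadcast each position j against each length
# Example: _sequence_mask_sketch([[1, 3], [2, 0]], 3) matches the docstring output above.
# ---------------------------------------------------------------------------
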
class SyncBatchNorm(Primitive):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks. This operation applies Batch Normalization over input
    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
    It rescales and recenters the features using a mini-batch of data and the learned parameters which
    can be described in the following formula,

    .. math::
        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`.

    Outputs:
        Tuple of 5 Tensor, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
        [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
        self.add_prim_attr('side_effect_mem', True)
        self.add_prim_attr('format', 'NCHW')
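
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the packaged file): the per-channel
# normalization formula from the SyncBatchNorm docstring,
# y = (x - mean) / sqrt(variance + eps) * gamma + beta, in plain numpy. The
# cross-device synchronization of mean/variance done by the real operator is
# outside the scope of this sketch.
def _batch_norm_formula_sketch(x, gamma, beta, mean, variance, eps=1e-5):
    import numpy as np
    x = np.asarray(x, dtype=np.float32)
    return (x - mean) / np.sqrt(variance + eps) * gamma + beta
# ---------------------------------------------------------------------------
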
class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32.
        - **axis** (Union[int, Tuple(int), List(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
        TypeError: If `axis` has non-Int elements.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
        [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        args = {'input_x': input_x['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, 'axis', self.name)
        elif axis:
            for index, one_axis in enumerate(axis_v):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out
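
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the packaged file): Centralization
# as documented above, y = x - mean(x, axis), with keepdims so the broadcast
# subtraction preserves the input shape.
def _centralization_sketch(x, axis=None):
    import numpy as np
    x = np.asarray(x, dtype=np.float32)
    return x - x.mean(axis=axis, keepdims=True)  # axis=None reduces all dimensions
# Example: _centralization_sketch([[1.0, 3.0]], axis=-1) -> [[-1.0, 1.0]]
# ---------------------------------------------------------------------------
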
class StackInit(PrimitiveWithInfer):
    """
    Create a stack that produces tensors in first-in last-out order.

    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
    at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> index = 0
        >>> stack = ops.StackInit(index)
        >>> push = ops.StackPush(index)
        >>> pop = ops.StackPop(index, x.shape, x.dtype)
        >>> destroy = ops.StackDestroy(index)
        >>> stack()
        >>> push(x)
        >>> y = pop()
        >>> destroy()
        >>> print(y)
        [[1 3]
        [2 0]]
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackInit"""
        validator.check_value_type("index", index, [int], self.name)


class StackPush(PrimitiveWithInfer):
    """
    Push a tensor onto the stack.

    Before `StackPush`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Inputs:
        - **input** (Tensor) - A tensor to be pushed onto the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackPush"""
        validator.check_value_type("index", index, [int], self.name)
        self.init_prim_io_names(inputs=['input'], outputs=[])


class StackPop(PrimitiveWithInfer):
    """
    Pop the tensor at the top of the stack.

    Before `StackPop`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.
        shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
        dtype (mindspore.dtype): The type of the tensor at the top of the stack. Default: mindspore.float32.

    Outputs:
        - **output** (Tensor) - The tensor at the top of the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
        """StackPop"""
        validator.check_value_type("index", index, [int], self.name)

        validator.check_value_type('shape type', shape, [list, tuple], self.name)
        validator.check_int(len(np.array(shape).shape), 1, validator.EQ, "dim of shape", self.name)
        for elem in shape:
            validator.check_int(elem, 1, validator.GE, 'shape element', self.name)
            validator.check_value_type('type of shape element', elem, [int], self.name)

        validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
        self.shape = shape
        self.dtype = dtype

        self.init_prim_io_names(inputs=[], outputs=['output'])

    def __infer__(self):
        return {'shape': (list(self.shape)),
                'dtype': (self.dtype),
                'value': None}


class StackDestroy(PrimitiveWithInfer):
    """
    Destroy the stack.

    Before `StackDestroy`, the stack should be created using `StackInit`.
|
|
965
|
+
Please refer to the usage in source code of `StackInit`.
|
|
966
|
+
|
|
967
|
+
Args:
|
|
968
|
+
index (int): The index of the stack. Default: 1.
|
|
969
|
+
|
|
970
|
+
Supported Platforms:
|
|
971
|
+
``Ascend``
|
|
972
|
+
|
|
973
|
+
Examples:
|
|
974
|
+
Please refer to the usage of `StackInit`.
|
|
975
|
+
"""
|
|
976
|
+
|
|
977
|
+
@prim_attr_register
|
|
978
|
+
def __init__(self, index=1):
|
|
979
|
+
"""StackDestroy"""
|
|
980
|
+
validator.check_value_type("index", index, [int], self.name)
|
|
981
|
+
|
|
982
|
+
|
|
983
|
+
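# Hedged sketch: the StackInit/StackPush/StackPop/StackDestroy primitives above
# behave like a keyed first-in-last-out container on the device. The toy Python
# version below only illustrates the intended call order and FILO behaviour;
# names prefixed with `toy_` are illustrative and not part of MindSpore.
_toy_stacks = {}

def toy_stack_init(index=1):
    _toy_stacks[index] = []

def toy_stack_push(index, tensor):
    _toy_stacks[index].append(tensor)

def toy_stack_pop(index=1):
    return _toy_stacks[index].pop()   # the last pushed tensor comes out first

def toy_stack_destroy(index=1):
    del _toy_stacks[index]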
class DynamicStitch(PrimitiveWithCheck):
|
|
984
|
+
r"""
|
|
985
|
+
Interleave the values from the data tensors into a single tensor.
|
|
986
|
+
|
|
987
|
+
Inputs:
|
|
988
|
+
- **indices** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
|
|
989
|
+
- **data** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
|
|
990
|
+
|
|
991
|
+
Outputs:
|
|
992
|
+
Tensor. A stacked Tensor with the same type as `data`.
|
|
993
|
+
|
|
994
|
+
Raises:
|
|
995
|
+
TypeError: If the data types of elements in `data` or `indices` are not the same.
|
|
996
|
+
ValueError: If the length of `data` or `indices` is not greater than 1.
|
|
997
|
+
|
|
998
|
+
Supported Platforms:
|
|
999
|
+
``Ascend``
|
|
1000
|
+
|
|
1001
|
+
Examples:
|
|
1002
|
+
>>> x1 = Tensor([6], mstype.int32)
|
|
1003
|
+
>>> x2 = Tensor(np.array([4, 1]), mstype.int32)
|
|
1004
|
+
>>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
|
|
1005
|
+
>>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
|
|
1006
|
+
>>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
|
|
1007
|
+
>>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
|
|
1008
|
+
>>> stitch = ops.DynamicStitch()
|
|
1009
|
+
>>> output = stitch([x1, x2, x3], [y1, y2, y3])
|
|
1010
|
+
>>> print(output)
|
|
1011
|
+
[[ 1 2]
|
|
1012
|
+
[11 12]
|
|
1013
|
+
[21 22]
|
|
1014
|
+
[31 32]
|
|
1015
|
+
[41 42]
|
|
1016
|
+
[51 52]
|
|
1017
|
+
[61 62]]
|
|
1018
|
+
"""
|
|
1019
|
+
|
|
1020
|
+
@prim_attr_register
|
|
1021
|
+
def __init__(self):
|
|
1022
|
+
"""Initialize DynamicStitch"""
|
|
1023
|
+
|
|
1024
|
+
def check_shape(self, indices_shape, data_shape):
|
|
1025
|
+
validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
|
|
1026
|
+
validator.check_int(len(indices_shape), 1, validator.GE, "len of indices_shape", self.name)
|
|
1027
|
+
indices_dim0 = len(indices_shape[0])
|
|
1028
|
+
indices_num = len(indices_shape)
|
|
1029
|
+
|
|
1030
|
+
validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
|
|
1031
|
+
validator.check_int(len(data_shape), 1, validator.GE, "len of data_shape", self.name)
|
|
1032
|
+
data_dim0 = len(data_shape[0])
|
|
1033
|
+
data_num = len(data_shape)
|
|
1034
|
+
|
|
1035
|
+
validator.check("size of indices", indices_num, 'size of data', data_num, validator.EQ, self.name)
|
|
1036
|
+
|
|
1037
|
+
# shape of `data` must start with shape of `indices`
|
|
1038
|
+
for i in range(0, indices_num):
|
|
1039
|
+
indices_dim = len(indices_shape[i])
|
|
1040
|
+
data_dim = len(data_shape[i])
|
|
1041
|
+
validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, validator.LE, self.name)
|
|
1042
|
+
if data_shape[i][:indices_dim] != indices_shape[i][:indices_dim]:
|
|
1043
|
+
raise ValueError(f"data[{i}].shape: {data_shape} does not start with indices[{i}].shape: {indices_shape}")
|
|
1044
|
+
|
|
1045
|
+
# the last-(data_dim0-indices_dim0)-dim of data shape must end with same shape.
|
|
1046
|
+
base_extra = data_dim0 - indices_dim0
|
|
1047
|
+
for i in range(0, data_num):
|
|
1048
|
+
indices_dim = len(indices_shape[i])
|
|
1049
|
+
data_dim = len(data_shape[i])
|
|
1050
|
+
extra = data_dim - indices_dim
|
|
1051
|
+
validator.check(f"extra dim of data[{i}]", extra,
|
|
1052
|
+
f"extra dim of data[0]", base_extra, validator.EQ, self.name)
|
|
1053
|
+
validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
|
|
1054
|
+
f"data[{i}].shape[{len(indices_shape[i])}:]",
|
|
1055
|
+
data_shape[i][indices_dim:], validator.EQ, self.name)
|
|
1056
|
+
|
|
1057
|
+
out_shape = [-1] + data_shape[0][indices_dim0:]
|
|
1058
|
+
return out_shape
|
|
1059
|
+
|
|
1060
|
+
def check_dtype(self, indices_type, data_type):
|
|
1061
|
+
validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
|
|
1062
|
+
validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
|
|
1063
|
+
indices_num = len(indices_type)
|
|
1064
|
+
for i in range(0, indices_num):
|
|
1065
|
+
validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
|
|
1066
|
+
validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
|
|
1067
|
+
mstype.number_type + (mstype.bool_,), self.name)
|
|
1068
|
+
validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]",
|
|
1069
|
+
data_type[0], validator.EQ, self.name)
|
|
1070
|
+
return data_type[0]
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
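# Hedged NumPy sketch of the DynamicStitch semantics above: every index tensor
# says where the matching rows of the corresponding data tensor land in the
# output. Illustration only; it assumes flat (1-D) index tensors for brevity
# and the helper name `dynamic_stitch_ref` is not part of MindSpore.
import numpy as np

def dynamic_stitch_ref(indices, data):
    n = max(int(idx.max()) for idx in indices) + 1
    out = np.zeros((n,) + data[0].shape[1:], dtype=data[0].dtype)
    for idx, dat in zip(indices, data):
        out[idx] = dat          # later tensors overwrite earlier ones on ties
    return out

# dynamic_stitch_ref([np.array([0, 2]), np.array([1])],
#                    [np.array([[10, 10], [30, 30]]), np.array([[20, 20]])])
# -> [[10 10], [20 20], [30 30]]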
class DynamicBroadcastGradientArgs(Primitive):
|
|
1074
|
+
"""
|
|
1075
|
+
Broadcast the two input shapes, return the dimensions that each need to be broadcast.
|
|
1076
|
+
|
|
1077
|
+
Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are either equal
|
|
1078
|
+
or input is one or the target dimension is -1. In case of -1 in target shape, it will be replaced by the input
|
|
1079
|
+
shape's value in that dimension.
|
|
1080
|
+
|
|
1081
|
+
Inputs:
|
|
1082
|
+
- **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
|
|
1083
|
+
uint32, uint64.
|
|
1084
|
+
- **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.
|
|
1085
|
+
|
|
1086
|
+
Outputs:
|
|
1087
|
+
Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one is the index tensor and the other one is the mask
|
|
1088
|
+
tensor.
|
|
1089
|
+
|
|
1090
|
+
- **r0** (Tensor) - The output shape is 1-D with the same type as s0.
|
|
1091
|
+
- **r1** (Tensor) - The output shape is 1-D with the same type as s0.
|
|
1092
|
+
|
|
1093
|
+
Raises:
|
|
1094
|
+
ValueError: If `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
|
|
1095
|
+
location.
|
|
1096
|
+
|
|
1097
|
+
Supported Platforms:
|
|
1098
|
+
``Ascend``
|
|
1099
|
+
|
|
1100
|
+
Examples:
|
|
1101
|
+
>>> shape0 = (4, 2, 1)
|
|
1102
|
+
>>> shape1 = (2, 7)
|
|
1103
|
+
>>> from mindspore.ops.operations import _inner_ops
|
|
1104
|
+
>>> args = _inner_ops.DynamicBroadcastGradientArgs()
|
|
1105
|
+
>>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
|
|
1106
|
+
>>> print(r0, r1)
|
|
1107
|
+
[2], [0]
|
|
1108
|
+
"""
|
|
1109
|
+
|
|
1110
|
+
@prim_attr_register
|
|
1111
|
+
def __init__(self):
|
|
1112
|
+
"""Init BroadcastGradientArgs"""
|
|
1113
|
+
|
|
1114
|
+
|
|
1115
|
+
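# Hedged sketch of what DynamicBroadcastGradientArgs computes: for two shapes
# that broadcast together, return the axes of each shape that were expanded
# (size 1 or missing) and therefore must be reduced when propagating gradients.
# Plain-Python illustration without the op's error checking; the helper name is
# illustrative only.
def broadcast_gradient_args_ref(s0, s1):
    rank = max(len(s0), len(s1))
    p0 = (1,) * (rank - len(s0)) + tuple(s0)   # left-pad with ones
    p1 = (1,) * (rank - len(s1)) + tuple(s1)
    r0 = [i for i in range(rank) if p0[i] == 1 and p1[i] != 1]
    r1 = [i for i in range(rank) if p1[i] == 1 and p0[i] != 1]
    return r0, r1

# broadcast_gradient_args_ref((4, 2, 1), (2, 7)) -> ([2], [0]),
# matching the docstring example above.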
class DSDMatmul(PrimitiveWithInfer):
|
|
1116
|
+
"""
|
|
1117
|
+
The definition of the DSDMatmul primitive.
|
|
1118
|
+
"""
|
|
1119
|
+
|
|
1120
|
+
@prim_attr_register
|
|
1121
|
+
def __init__(self):
|
|
1122
|
+
self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y'])
|
|
1123
|
+
|
|
1124
|
+
def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape):
|
|
1125
|
+
batch_size = input_w1_shape[0]
|
|
1126
|
+
head = input_w1_shape[1]
|
|
1127
|
+
v_embedding = input_v_shape[1] * 16 // head
|
|
1128
|
+
seq_len = input_v_shape[0] * 16 // batch_size
|
|
1129
|
+
return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
|
|
1130
|
+
|
|
1131
|
+
def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3):
|
|
1132
|
+
return data_dtype1
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
class MatmulDDS(PrimitiveWithInfer):
|
|
1136
|
+
"""MatmulDDS definition"""
|
|
1137
|
+
|
|
1138
|
+
@prim_attr_register
|
|
1139
|
+
def __init__(self, bs, heads):
|
|
1140
|
+
"""init MatmulDDS"""
|
|
1141
|
+
self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'],
|
|
1142
|
+
outputs=['local_prob', 'global_prob'])
|
|
1143
|
+
|
|
1144
|
+
self.heads = heads
|
|
1145
|
+
|
|
1146
|
+
def infer_shape(self, q, k, local_mask, global_mask):
|
|
1147
|
+
seq_len = local_mask[0] * local_mask[-1]
|
|
1148
|
+
bs = q[1] * q[2] // seq_len
|
|
1149
|
+
global_size = seq_len // 4
|
|
1150
|
+
size_per_head = q[0] * q[-1] // self.heads
|
|
1151
|
+
heads = q[0] * q[-1] // size_per_head
|
|
1152
|
+
block_size = local_mask[1] * local_mask[2] // bs
|
|
1153
|
+
block_num = seq_len // block_size
|
|
1154
|
+
l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)
|
|
1155
|
+
g_size = (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16)
|
|
1156
|
+
|
|
1157
|
+
return l_size, g_size
|
|
1158
|
+
|
|
1159
|
+
def infer_dtype(self, q, k, local_mask, global_mask):
|
|
1160
|
+
return q, q
|
|
1161
|
+
|
|
1162
|
+
|
|
1163
|
+
class DSDGrad(PrimitiveWithInfer):
|
|
1164
|
+
"""
|
|
1165
|
+
The definition of the DSDGrad primitive.
|
|
1166
|
+
"""
|
|
1167
|
+
|
|
1168
|
+
@prim_attr_register
|
|
1169
|
+
def __init__(self):
|
|
1170
|
+
self.init_prim_io_names(inputs=['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm'],
|
|
1171
|
+
outputs=['d_w1_gm', 'd_w2_gm', 'd_v_gm'])
|
|
1172
|
+
|
|
1173
|
+
def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape):
|
|
1174
|
+
return input_w1_shape, input_w2_shape, input_v_shape
|
|
1175
|
+
|
|
1176
|
+
def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5):
|
|
1177
|
+
return data_dtype1, data_dtype1, data_dtype1
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
class MatmulDDSGrad(PrimitiveWithInfer):
    """MatmulDDSGrad definition"""
|
|
1182
|
+
|
|
1183
|
+
@prim_attr_register
|
|
1184
|
+
def __init__(self):
    """init MatmulDDSGrad"""
|
|
1186
|
+
self.init_prim_io_names(inputs=['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad'],
|
|
1187
|
+
outputs=['dq', 'dk'])
|
|
1188
|
+
|
|
1189
|
+
def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
|
|
1190
|
+
k_size = (q[1], q[0], q[3], q[2])
|
|
1191
|
+
|
|
1192
|
+
return q, k_size
|
|
1193
|
+
|
|
1194
|
+
def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
|
|
1195
|
+
return q, k
|
|
1196
|
+
|
|
1197
|
+
|
|
1198
|
+
class NonZeroWithValue(Primitive):
|
|
1199
|
+
"""
|
|
1200
|
+
Returns the value of elements that are non-zero (in row-major order - by dimension).
|
|
1201
|
+
|
|
1202
|
+
Inputs:
|
|
1203
|
+
- **x** (Tensor), input array of rank >= 2.
|
|
1204
|
+
|
|
1205
|
+
Outputs:
|
|
1206
|
+
elements that are non-zero.
|
|
1207
|
+
|
|
1208
|
+
Supported Platforms:
|
|
1209
|
+
``Ascend``
|
|
1210
|
+
|
|
1211
|
+
Examples:
|
|
1212
|
+
>>> op = NonZeroWithValue()
|
|
1213
|
+
>>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
|
|
1214
|
+
>>> value, index, count = op(data)
|
|
1215
|
+
>>> print(value)
|
|
1216
|
+
[1.0, 1.0]
|
|
1217
|
+
"""
|
|
1218
|
+
|
|
1219
|
+
@prim_attr_register
|
|
1220
|
+
def __init__(self, transpose=False):
|
|
1221
|
+
"""Initialize NonZeroWithValue"""
|
|
1222
|
+
validator.check_value_type("transpose", transpose, [bool], self.name)
|
|
1223
|
+
self.init_prim_io_names(inputs=['x'], outputs=['value', 'index', 'count'])
|
|
1224
|
+
|
|
1225
|
+
|
|
1226
|
+
class NonZeroWithValueShape(Primitive):
|
|
1227
|
+
"""
|
|
1228
|
+
Returns the value and index of elements that are non-zero (in row-major order - by dimension).
|
|
1229
|
+
|
|
1230
|
+
Inputs:
|
|
1231
|
+
- **x** (Tensor), input array of rank >= 2.
|
|
1232
|
+
|
|
1233
|
+
Outputs:
|
|
1234
|
+
elements that are non-zero.
|
|
1235
|
+
|
|
1236
|
+
Supported Platforms:
|
|
1237
|
+
``Ascend``
|
|
1238
|
+
|
|
1239
|
+
Examples:
|
|
1240
|
+
>>> non_zero = NonZeroWithValue()
|
|
1241
|
+
>>> op = NonZeroWithValueShape()
|
|
1242
|
+
>>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
|
|
1243
|
+
>>> value, index, count = non_zero(data)
|
|
1244
|
+
>>> out_value, out_index = op(value, index, count)
|
|
1245
|
+
>>> print(out_index)
|
|
1246
|
+
[[0, 1], [0, 2]]
|
|
1247
|
+
"""
|
|
1248
|
+
|
|
1249
|
+
@prim_attr_register
|
|
1250
|
+
def __init__(self):
|
|
1251
|
+
"""Initialize NonZeroWithValueShape"""
|
|
1252
|
+
self.init_prim_io_names(inputs=['value', 'index', 'count'], outputs=['out_value', 'out_index'])
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
class DecodeImage(PrimitiveWithInfer):
|
|
1256
|
+
"""
|
|
1257
|
+
Returns image data parsed from a string Tensor.
|
|
1258
|
+
|
|
1259
|
+
Inputs:
|
|
1260
|
+
- **x** (Tensor), a Tensor of type string. 0-D. The JPEG, GIF, PNG or BMP-encoded image.
|
|
1261
|
+
|
|
1262
|
+
Outputs:
|
|
1263
|
+
A Tensor of type uint8, uint16, float.
|
|
1264
|
+
|
|
1265
|
+
Supported Platforms:
|
|
1266
|
+
``Ascend``
|
|
1267
|
+
|
|
1268
|
+
Examples:
|
|
1269
|
+
"""
|
|
1270
|
+
|
|
1271
|
+
@prim_attr_register
|
|
1272
|
+
def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
|
|
1273
|
+
_op_max_size=[8000000]):
|
|
1274
|
+
self.init_prim_io_names(inputs=["contents"], outputs=["image"])
|
|
1275
|
+
self.res_type = dtype
|
|
1276
|
+
|
|
1277
|
+
def infer_shape(self, x):
|
|
1278
|
+
return (-1, -1, 3)
|
|
1279
|
+
|
|
1280
|
+
def infer_dtype(self, x):
|
|
1281
|
+
return self.res_type
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
class SliceGetItem(Primitive):
|
|
1285
|
+
"""
|
|
1286
|
+
Use SliceGetItem to get a slice object's 'start', 'stop' or 'step' attribute.
|
|
1287
|
+
"""
|
|
1288
|
+
|
|
1289
|
+
@prim_attr_register
|
|
1290
|
+
def __init__(self):
    """Initialize SliceGetItem"""
|
|
1292
|
+
self.init_prim_io_names(inputs=['slice', 'attr'], outputs=['slice_item'])
|
|
1293
|
+
|
|
1294
|
+
def __call__(self, slice_value, value):
|
|
1295
|
+
if not isinstance(slice_value, slice):
|
|
1296
|
+
raise TypeError(
|
|
1297
|
+
"Primitive[SliceGetItem] only support to get a slice type element but got {}".format(slice_value))
|
|
1298
|
+
if value == "start":
|
|
1299
|
+
if hasattr(slice_value.start, "ndim") and slice_value.start.ndim == 1:
|
|
1300
|
+
return slice_value.start.item()
|
|
1301
|
+
return slice_value.start
|
|
1302
|
+
if value == "stop":
|
|
1303
|
+
if hasattr(slice_value.stop, "ndim") and slice_value.stop.ndim == 1:
|
|
1304
|
+
return slice_value.stop.item()
|
|
1305
|
+
return slice_value.stop
|
|
1306
|
+
if value == "step":
|
|
1307
|
+
if hasattr(slice_value.step, "ndim") and slice_value.step.ndim == 1:
|
|
1308
|
+
return slice_value.step.item()
|
|
1309
|
+
return slice_value.step
|
|
1310
|
+
raise AttributeError("\'slice\' object has no attribute {}".format(value))
|
|
1311
|
+
|
|
1312
|
+
|
|
1313
|
+
class DynamicBroadcastTo(Primitive):
|
|
1314
|
+
"""
|
|
1315
|
+
Broadcasts input tensor to a given shape.
|
|
1316
|
+
|
|
1317
|
+
Inputs:
|
|
1318
|
+
- **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
|
|
1319
|
+
float16, float32, int32, int8, uint8.
|
|
1320
|
+
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
|
|
1321
|
+
- **shape** (Tensor): The target shape to broadcast.
|
|
1322
|
+
|
|
1323
|
+
Outputs:
|
|
1324
|
+
Tensor, with the given `shape` and the same data type as `input_x`.
|
|
1325
|
+
|
|
1326
|
+
Raises:
|
|
1327
|
+
ValueError: if the target and input shapes are incompatible.
|
|
1328
|
+
|
|
1329
|
+
Supported Platforms:
|
|
1330
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1331
|
+
"""
|
|
1332
|
+
|
|
1333
|
+
@prim_attr_register
|
|
1334
|
+
def __init__(self):
|
|
1335
|
+
"""Initialize DynamicBroadcastTo"""
|
|
1336
|
+
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])
|
|
1337
|
+
|
|
1338
|
+
|
|
1339
|
+
class DynamicResizeNearestNeighbor(Primitive):
|
|
1340
|
+
r"""
|
|
1341
|
+
Resizes the input tensor by using the nearest neighbor algorithm.
|
|
1342
|
+
|
|
1343
|
+
Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
|
|
1344
|
+
neighbor algorithm selects the value of the nearest point and does not consider the
|
|
1345
|
+
values of neighboring points at all, yielding a piecewise-constant interpolant.
|
|
1346
|
+
|
|
1347
|
+
Note:
|
|
1348
|
+
The operator supports dynamic shape.
|
|
1349
|
+
|
|
1350
|
+
Args:
|
|
1351
|
+
align_corners (bool): Whether the centers of the 4 corner pixels of the input
|
|
1352
|
+
and output tensors are aligned. Default: ``False``.
|
|
1353
|
+
|
|
1354
|
+
Inputs:
|
|
1355
|
+
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
|
|
1356
|
+
- **size** (Union[tuple, list]): The target size. The dimension of size must be 2.
|
|
1357
|
+
|
|
1358
|
+
Outputs:
|
|
1359
|
+
Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
|
|
1360
|
+
The data type is the same as the `input_x`.
|
|
1361
|
+
"""
|
|
1362
|
+
|
|
1363
|
+
@prim_attr_register
|
|
1364
|
+
def __init__(self, align_corners=False):
    """Initialize DynamicResizeNearestNeighbor"""
|
|
1366
|
+
validator.check_value_type("align_corners", align_corners, [bool], self.name)
|
|
1367
|
+
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
|
|
1368
|
+
|
|
1369
|
+
|
|
1370
|
+
class PsROIPooling(PrimitiveWithInfer):
|
|
1371
|
+
r"""
|
|
1372
|
+
Position Sensitive ROI-Pooling
|
|
1373
|
+
Inputs:
|
|
1374
|
+
- feature(Tensor)
|
|
1375
|
+
- rois(Tensor)
|
|
1376
|
+
|
|
1377
|
+
- **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`.
|
|
1378
|
+
- **rois** (Tensor) - The shape is :math:`(rois\_n, 5)`. With data type of float16 or float32.
|
|
1379
|
+
`rois_n` represents the number of RoIs. The size of the second dimension must be `5` and the `5` columns
|
|
1380
|
+
are :math:`(image\_index, top\_left\_x, top\_left\_y, bottom\_right\_x, bottom\_right\_y)`.
|
|
1381
|
+
`image_index` represents the index of image. `top_left_x` and `top_left_y` represent the `x, y`
|
|
1382
|
+
coordinates of the top left corner of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y`
|
|
1383
|
+
represent the `x, y` coordinates of the bottom right corner of corresponding RoI, respectively.
|
|
1384
|
+
|
|
1385
|
+
Outputs:
|
|
1386
|
+
- **out** - Tensor of shape (rois_num, out_channel, pool_height, pool_width), the result after pooling.
|
|
1387
|
+
- **channel_map** - Tensor of shape (rois_num, out_channel, pool_height, pool_width), used in the backward pass to compute gradients.
|
|
1388
|
+
Supported Platforms:
|
|
1389
|
+
``GPU``
|
|
1390
|
+
|
|
1391
|
+
Examples:
|
|
1392
|
+
>>> import mindspore
|
|
1393
|
+
>>> import numpy as np
|
|
1394
|
+
>>> from mindspore import Tensor
|
|
1395
|
+
>>> from mindspore.ops.operations import _inner_ops as inner
|
|
1396
|
+
>>> features = np.random.randn(4, 21 * 7 * 7, 80, 48)
|
|
1397
|
+
>>> features = Tensor.from_numpy(features).astype(mindspore.float32)
|
|
1398
|
+
>>> rois = Tensor.from_numpy(
...     np.array([
...         [0.0000, 150.3563, 200.1320, 579.3563, 602.3452],
...         [1.0000, 657.1263, 302.8564, 762.4214, 567.9854],
...         [2.0000, 321.3122, 232.2410, 679.0281, 587.6346],
...         [3.0000, 664.1630, 387.4919, 778.7322, 562.7321],
...     ])).astype(mindspore.float32)
>>> psRoIPooling = inner.PsROIPooling(pooled_height=7, pooled_width=7, num_rois=4,
...                                   spatial_scale=1.0/16, out_dim=21,
...                                   group_size=7)
|
|
1408
|
+
>>> out, channel_map = psRoIPooling(features, rois)
|
|
1409
|
+
>>> print(out.shape)
|
|
1410
|
+
(4, 21, 7, 7)
|
|
1411
|
+
>>> print(channel_map.shape)
|
|
1412
|
+
(4, 21, 7, 7)
|
|
1413
|
+
"""
|
|
1414
|
+
|
|
1415
|
+
@prim_attr_register
|
|
1416
|
+
def __init__(self, pooled_height, pooled_width, num_rois, spatial_scale, out_dim, group_size):
|
|
1417
|
+
"""Initialize PsROIPooling"""
|
|
1418
|
+
validator.check_value_type("pooled_height", pooled_height, [int], self.name)
|
|
1419
|
+
validator.check_value_type("pooled_width", pooled_width, [int], self.name)
|
|
1420
|
+
validator.check_value_type("num_rois", num_rois, [int], self.name)
|
|
1421
|
+
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
|
|
1422
|
+
validator.check_value_type("out_dim", out_dim, [int], self.name)
|
|
1423
|
+
validator.check_value_type("group_size", group_size, [int], self.name)
|
|
1424
|
+
self.pooled_height = pooled_height
|
|
1425
|
+
self.pooled_width = pooled_width
|
|
1426
|
+
self.num_rois = num_rois
|
|
1427
|
+
self.spatial_scale = spatial_scale
|
|
1428
|
+
self.out_dim = out_dim
|
|
1429
|
+
self.group_size = group_size
|
|
1430
|
+
|
|
1431
|
+
def infer_shape(self, inputs_shape, rois_shape):
|
|
1432
|
+
output_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
|
|
1433
|
+
output_map_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
|
|
1434
|
+
return output_shape, output_map_shape
|
|
1435
|
+
|
|
1436
|
+
def infer_dtype(self, inputs_type, rois_type):
|
|
1437
|
+
map_type = mstype.TensorType(mstype.int32)
|
|
1438
|
+
return inputs_type, map_type
|
|
1439
|
+
|
|
1440
|
+
|
|
1441
|
+
class ParallelResizeBilinear(PrimitiveWithInfer):
|
|
1442
|
+
"""ParallelResizeBilinear ops"""
|
|
1443
|
+
|
|
1444
|
+
@prim_attr_register
|
|
1445
|
+
def __init__(self, ori_image_size, split_size, src_start_w, dst_start_w, align_corners):
|
|
1446
|
+
"""Initialize ParallelResizeBilinear."""
|
|
1447
|
+
validator.check_value_type("ori_image_size", ori_image_size, [list, tuple], self.name)
|
|
1448
|
+
validator.check_value_type("split_size", split_size, [list, tuple], self.name)
|
|
1449
|
+
validator.check_int(len(split_size), 2, validator.EQ, "len of split_size", self.name)
|
|
1450
|
+
validator.check_value_type("src_start_w", src_start_w, [int], self.name)
|
|
1451
|
+
validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
|
|
1452
|
+
validator.check_value_type("align_corners", align_corners, [bool], self.name)
|
|
1453
|
+
self.ori_image_size = list(ori_image_size)
|
|
1454
|
+
self.split_size = list(split_size)
|
|
1455
|
+
self.src_start_w = src_start_w
|
|
1456
|
+
self.dst_start_w = dst_start_w
|
|
1457
|
+
self.align_corners = align_corners
|
|
1458
|
+
self.half_pixel_centers = False
|
|
1459
|
+
self.add_prim_attr('ori_image_size', self.ori_image_size)
|
|
1460
|
+
self.add_prim_attr('split_size', self.split_size)
|
|
1461
|
+
self.add_prim_attr('src_start_w', self.src_start_w)
|
|
1462
|
+
self.add_prim_attr('dst_start_w', self.dst_start_w)
|
|
1463
|
+
self.add_prim_attr('align_corners', self.align_corners)
|
|
1464
|
+
self.add_prim_attr('half_pixel_centers', self.half_pixel_centers)
|
|
1465
|
+
|
|
1466
|
+
def __infer__(self, x, size):
|
|
1467
|
+
size_val = size['value']
|
|
1468
|
+
x_shape = x['shape']
|
|
1469
|
+
x_dtype = x['dtype']
|
|
1470
|
+
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
|
|
1471
|
+
if size_val is None:
|
|
1472
|
+
raise ValueError("size must be const input")
|
|
1473
|
+
output_shape = [x_shape[0], x_shape[1], self.split_size[0], self.split_size[1]]
|
|
1474
|
+
|
|
1475
|
+
return {'shape': output_shape,
|
|
1476
|
+
'dtype': x_dtype,
|
|
1477
|
+
'value': None}
|
|
1478
|
+
|
|
1479
|
+
|
|
1480
|
+
class PartitionedCall(PrimitiveWithInfer):
|
|
1481
|
+
"""
|
|
1482
|
+
Pass the input tensors to the subgraph and return the output tensors.
|
|
1483
|
+
|
|
1484
|
+
Inputs:
|
|
1485
|
+
- **inputs** (Tuple), the input tensors, which will be passed to subgraph.
|
|
1486
|
+
|
|
1487
|
+
Outputs:
|
|
1488
|
+
- outputs(Tuple), the output tensor returned by subgraph.
|
|
1489
|
+
|
|
1490
|
+
Supported Platforms:
|
|
1491
|
+
``Ascend``
|
|
1492
|
+
|
|
1493
|
+
Examples:
|
|
1494
|
+
"""
|
|
1495
|
+
|
|
1496
|
+
@prim_attr_register
|
|
1497
|
+
def __init__(self, graph, executor_type=""):
|
|
1498
|
+
super(PartitionedCall, self).__init__(self.__class__.__name__)
|
|
1499
|
+
self.add_prim_attr("executor_type", executor_type)
|
|
1500
|
+
self.graph = graph
|
|
1501
|
+
|
|
1502
|
+
def infer_shape(self, *inputs):
|
|
1503
|
+
raise NotImplementedError
|
|
1504
|
+
|
|
1505
|
+
def infer_dtype(self, *inputs):
|
|
1506
|
+
raise NotImplementedError
|
|
1507
|
+
|
|
1508
|
+
|
|
1509
|
+
class CellBackwardHook(PrimitiveWithInfer):
|
|
1510
|
+
r"""
|
|
1511
|
+
This operator is used to hook input gradient and output gradient of Cell object.
|
|
1512
|
+
|
|
1513
|
+
Note:
|
|
1514
|
+
This operator is only used in backward hook function of Cell object in pynative mode.
|
|
1515
|
+
|
|
1516
|
+
Args:
|
|
1517
|
+
cell_id (str): Used to identify which Cell object the hook function is registered on. For example, 'nn.Add()' is a
|
|
1518
|
+
cell object.
|
|
1519
|
+
|
|
1520
|
+
Inputs:
|
|
1521
|
+
- **input** - The variable to hook.
|
|
1522
|
+
|
|
1523
|
+
Outputs:
|
|
1524
|
+
- **output** - Returns `input` directly. `CellBackwardHook` does not affect the forward result.
|
|
1525
|
+
|
|
1526
|
+
Supported Platforms:
|
|
1527
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1528
|
+
|
|
1529
|
+
Examples:
|
|
1530
|
+
>>> import mindspore as ms
|
|
1531
|
+
>>> from mindspore import Tensor
|
|
1532
|
+
>>> from mindspore.ops import GradOperation
|
|
1533
|
+
>>> from mindspore.ops.operations import _inner_ops as inner
|
|
1534
|
+
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
|
|
1535
|
+
>>> def hook_fn(grad):
|
|
1536
|
+
... print(grad)
|
|
1537
|
+
...
|
|
1538
|
+
>>> hook = inner.CellBackwardHook()
|
|
1539
|
+
>>> hook_fn_key = hook.register_backward_hook()
|
|
1540
|
+
>>> def hook_test(x, y):
|
|
1541
|
+
... z = x * y
|
|
1542
|
+
... z = hook(z)
|
|
1543
|
+
... z = z * y
|
|
1544
|
+
... return z
|
|
1545
|
+
...
|
|
1546
|
+
>>> grad_all = GradOperation(get_all=True)
|
|
1547
|
+
>>> def backward(x, y):
|
|
1548
|
+
... return grad_all(hook_test)(x, y)
|
|
1549
|
+
...
|
|
1550
|
+
>>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
|
|
1551
|
+
(Tensor(shape=[], dtype=Float32, value= 2),)
|
|
1552
|
+
>>> print(output)
|
|
1553
|
+
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
|
|
1554
|
+
>>> hook.remove_backward_hook(hook_fn_key)
|
|
1555
|
+
>>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
|
|
1556
|
+
>>> print(output)
|
|
1557
|
+
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
|
|
1558
|
+
"""
|
|
1559
|
+
|
|
1560
|
+
def __init__(self, cell_id="", cell=None, hook_dict=None):
|
|
1561
|
+
"""Initialize CellBackwardHook"""
|
|
1562
|
+
super(CellBackwardHook, self).__init__(self.__class__.__name__)
|
|
1563
|
+
self.cell_id = cell_id
|
|
1564
|
+
self.cell = cell
|
|
1565
|
+
self.hook_dict = weakref.ref(hook_dict)
|
|
1566
|
+
self.add_prim_attr("cell_id", cell_id)
|
|
1567
|
+
self.grad_output = None
|
|
1568
|
+
|
|
1569
|
+
def __call__(self, *args):
|
|
1570
|
+
# If args is empty, just return.
|
|
1571
|
+
if not args:
|
|
1572
|
+
return args
|
|
1573
|
+
return _run_op(self, self.name, args)
|
|
1574
|
+
|
|
1575
|
+
def infer_shape(self, *inputs_shape):
|
|
1576
|
+
if len(inputs_shape) == 1:
|
|
1577
|
+
return inputs_shape[0]
|
|
1578
|
+
return inputs_shape
|
|
1579
|
+
|
|
1580
|
+
def infer_dtype(self, *inputs_type):
|
|
1581
|
+
if len(inputs_type) == 1:
|
|
1582
|
+
return inputs_type[0]
|
|
1583
|
+
return inputs_type
|
|
1584
|
+
|
|
1585
|
+
def register_backward_hook(self):
|
|
1586
|
+
"""
|
|
1587
|
+
Register the backward hook function.
|
|
1588
|
+
|
|
1589
|
+
Args:
|
|
1590
|
+
None
|
|
1591
|
+
|
|
1592
|
+
Returns:
|
|
1593
|
+
None
|
|
1594
|
+
|
|
1595
|
+
Supported Platforms:
|
|
1596
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1597
|
+
"""
|
|
1598
|
+
|
|
1599
|
+
def hook_backward_grad(grad):
|
|
1600
|
+
if self.grad_output is None:
|
|
1601
|
+
self.grad_output = grad
|
|
1602
|
+
# First call of the backward hook; wait for the second call before running the user hooks.
|
|
1603
|
+
return self.cell_id
|
|
1604
|
+
backward_hook_grad_input = grad
|
|
1605
|
+
if self.hook_dict():
|
|
1606
|
+
backward_hooks = self.hook_dict().values()
|
|
1607
|
+
for hook in backward_hooks:
|
|
1608
|
+
res = hook(self.cell, backward_hook_grad_input, self.grad_output)
|
|
1609
|
+
if res is None:
|
|
1610
|
+
continue
|
|
1611
|
+
if not isinstance(res, tuple):
|
|
1612
|
+
res = (res,)
|
|
1613
|
+
if len(res) != len(grad):
|
|
1614
|
+
raise TypeError(
|
|
1615
|
+
"The backward hook return value size is {} not equal to expect grad input size {}".format(
|
|
1616
|
+
len(res), len(grad)))
|
|
1617
|
+
backward_hook_grad_input = res
|
|
1618
|
+
self.grad_output = None
|
|
1619
|
+
return backward_hook_grad_input
|
|
1620
|
+
|
|
1621
|
+
self.set_hook_fn(hook_backward_grad, HookType.BackwardHook)
|
|
1622
|
+
|
|
1623
|
+
def register_backward_pre_hook(self):
|
|
1624
|
+
"""
|
|
1625
|
+
Register the backward pre hook function.
|
|
1626
|
+
|
|
1627
|
+
Args:
|
|
1628
|
+
None
|
|
1629
|
+
|
|
1630
|
+
Returns:
|
|
1631
|
+
None
|
|
1632
|
+
|
|
1633
|
+
Supported Platforms:
|
|
1634
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1635
|
+
"""
|
|
1636
|
+
|
|
1637
|
+
def hook_backward_pre_grad(grad):
|
|
1638
|
+
backward_pre_hook_grad = grad
|
|
1639
|
+
if self.hook_dict():
|
|
1640
|
+
backward_pre_hooks = self.hook_dict().values()
|
|
1641
|
+
for hook in backward_pre_hooks:
|
|
1642
|
+
res = hook(self.cell, backward_pre_hook_grad)
|
|
1643
|
+
if res is None:
|
|
1644
|
+
continue
|
|
1645
|
+
if not isinstance(res, tuple):
|
|
1646
|
+
res = (res,)
|
|
1647
|
+
if len(res) != len(grad):
|
|
1648
|
+
raise TypeError(
|
|
1649
|
+
"The backward pre hook return value size is {} not equal to expect output size {}".format(
|
|
1650
|
+
len(res), len(grad)))
|
|
1651
|
+
backward_pre_hook_grad = res
|
|
1652
|
+
return backward_pre_hook_grad
|
|
1653
|
+
|
|
1654
|
+
self.set_hook_fn(hook_backward_pre_grad, HookType.BackwardPreHook)
|
|
1655
|
+
|
|
1656
|
+
|
|
1657
|
+
class Format(PrimitiveWithInfer):
|
|
1658
|
+
r"""
|
|
1659
|
+
This operator is used to format a string.
|
|
1660
|
+
|
|
1661
|
+
Note:
|
|
1662
|
+
Currently not supported for direct use by users.
|
|
1663
|
+
Only str.format() calls in user code are supported; they will be converted to a Format
|
|
1664
|
+
operation by ME-Compiler automatically.
|
|
1665
|
+
|
|
1666
|
+
|
|
1667
|
+
Inputs:
|
|
1668
|
+
- **input** -
|
|
1669
|
+
string : the string to be formatted.
|
|
1670
|
+
args : the format args.
|
|
1671
|
+
|
|
1672
|
+
Outputs:
|
|
1673
|
+
- **output** - Returns formatted string.
|
|
1674
|
+
|
|
1675
|
+
Supported Platforms:
|
|
1676
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1677
|
+
"""
|
|
1678
|
+
|
|
1679
|
+
@prim_attr_register
|
|
1680
|
+
def __init__(self):
|
|
1681
|
+
self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
|
|
1682
|
+
|
|
1683
|
+
def __infer__(self, str_, *var):
|
|
1684
|
+
def check_variable(str_, var):
|
|
1685
|
+
if _check_contains_variable(str_['dtype'], str_['value']):
|
|
1686
|
+
return True
|
|
1687
|
+
|
|
1688
|
+
for item in var:
|
|
1689
|
+
if _check_contains_variable(item['dtype'], item['value']):
|
|
1690
|
+
return True
|
|
1691
|
+
return False
|
|
1692
|
+
|
|
1693
|
+
if check_variable(str_, var):
|
|
1694
|
+
return {'dtype': mstype.string, 'shape': [], 'value': None}
|
|
1695
|
+
|
|
1696
|
+
str_value = str_['value']
|
|
1697
|
+
kwargs = dict()
|
|
1698
|
+
var_value = list()
|
|
1699
|
+
|
|
1700
|
+
for item in var:
|
|
1701
|
+
if isinstance(item["dtype"], typing.Keyword):
|
|
1702
|
+
kwargs.update(item["value"])
|
|
1703
|
+
var_value.append(item["value"])
|
|
1704
|
+
|
|
1705
|
+
value = str_value.format(*var_value, **kwargs)
|
|
1706
|
+
return {'dtype': mstype.string, 'shape': [], 'value': value}
|
|
1707
|
+
|
|
1708
|
+
|
|
1709
|
+
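# Hedged, simplified sketch of the constant-folding path in Format.__infer__
# above: when the template string and every argument are compile-time constants,
# the result is just Python's built-in str.format applied to them. The helper
# name `format_fold_ref` is illustrative; variable inputs are deferred to run
# time by the real primitive and are not modelled here.
def format_fold_ref(template, *args, **kwargs):
    return template.format(*args, **kwargs)

# format_fold_ref("loss={:.2f} step={}", 0.25, 10) -> 'loss=0.25 step=10'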
class FlattenConcat(Primitive):
|
|
1710
|
+
"""
|
|
1711
|
+
Flatten input tensors and concatenate them into several chunk tensors grouped by data types.
|
|
1712
|
+
|
|
1713
|
+
Args:
|
|
1714
|
+
fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.
|
|
1715
|
+
|
|
1716
|
+
Inputs:
|
|
1717
|
+
- **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.
|
|
1718
|
+
|
|
1719
|
+
Outputs:
|
|
1720
|
+
tuple[Tensor], result chunk tensors.
|
|
1721
|
+
|
|
1722
|
+
Supported Platforms:
|
|
1723
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1724
|
+
|
|
1725
|
+
Examples:
|
|
1726
|
+
>>> from mindspore.ops.operations import _inner_ops as inner
|
|
1727
|
+
>>> t1 = Tensor(np.array([1]).astype(np.float32))
|
|
1728
|
+
>>> t2 = Tensor(np.array([2]).astype(np.float32))
|
|
1729
|
+
>>> t3 = Tensor(np.array([3]).astype(np.float64))
|
|
1730
|
+
>>> t4 = Tensor(np.array([4]).astype(np.float32))
|
|
1731
|
+
>>> t5 = Tensor(np.array([5]).astype(np.float64))
|
|
1732
|
+
>>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
|
|
1733
|
+
>>> print(chunks[0].asnumpy())
[1. 2. 4.]
>>> print(chunks[1].asnumpy())
[3. 5.]
|
|
1737
|
+
"""
|
|
1738
|
+
|
|
1739
|
+
@prim_attr_register
|
|
1740
|
+
def __init__(self, fusion_size=0):
|
|
1741
|
+
"""Initialize FlattenConcat"""
|
|
1742
|
+
validator.check_non_negative_int(fusion_size, 'fusion_size', self.name)
|
|
1743
|
+
self.fusion_size = fusion_size
|
|
1744
|
+
self.add_prim_attr('fusion_size', fusion_size)
|
|
1745
|
+
|
|
1746
|
+
|
|
1747
|
+
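# Hedged NumPy sketch of the FlattenConcat behaviour documented above: inputs
# are flattened and concatenated into one chunk per data type. Splitting chunks
# by `fusion_size` is ignored here for brevity, and the helper name
# `flatten_concat_ref` is illustrative only.
import numpy as np

def flatten_concat_ref(tensors):
    chunks = {}
    for t in tensors:
        chunks.setdefault(t.dtype, []).append(t.reshape(-1))
    return [np.concatenate(parts) for parts in chunks.values()]

# flatten_concat_ref([np.float32([1]), np.float32([2]), np.float64([3])])
# -> [array([1., 2.], dtype=float32), array([3.])]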
class KMeansCentroids(PrimitiveWithInfer):
|
|
1748
|
+
"""
|
|
1749
|
+
Calculates the segment_sum, segment_count and kmean_total_sum of the clustering results.
|
|
1750
|
+
|
|
1751
|
+
Args:
|
|
1752
|
+
use_actual_distance (bool): A bool value to decide whether do complete calculation of distance.
|
|
1753
|
+
|
|
1754
|
+
Inputs:
|
|
1755
|
+
- **x** (Tensor(float32)) - Input data used for clustering
|
|
1756
|
+
- **y** (Tensor(float32)) - Initial centroids of clustering
|
|
1757
|
+
- **sum_square_y** (Tensor(float32)) - The result of preprocessing such as square, reduce and transpose of y
|
|
1758
|
+
- **sum_square_x** (Tensor(float32)) - The result of preprocessing such as square and reduce of x
|
|
1759
|
+
|
|
1760
|
+
Outputs:
|
|
1761
|
+
- **segment_sum** (Tensor(float32)) - Clustering result w.r.t. each centroid
|
|
1762
|
+
- **segment_count** (Tensor(float32)) - Clustering count w.r.t. each centroid
|
|
1763
|
+
- **kmean_total_sum** (Tensor(float32)) - The sum of the distances from all vectors to their nearest centroid
|
|
1764
|
+
|
|
1765
|
+
Supported Platforms:
|
|
1766
|
+
``Ascend``
|
|
1767
|
+
|
|
1768
|
+
Examples:
|
|
1769
|
+
>>> import numpy as np
|
|
1770
|
+
>>> import mindspore as ms
|
|
1771
|
+
>>> import mindspore.common.dtype as mstype
|
|
1772
|
+
>>> import mindspore.nn as nn
|
|
1773
|
+
>>> from mindspore import Tensor
|
|
1774
|
+
>>> from mindspore.ops import operations as P
|
|
1775
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
|
|
1776
|
+
|
|
1777
|
+
>>> class Net(nn.Cell):
|
|
1778
|
+
>>> def __init__(self):
|
|
1779
|
+
>>> super(Net, self).__init__()
|
|
1780
|
+
>>> self.reduce_sum = P.ReduceSum(keep_dims=True)
|
|
1781
|
+
>>> self.square = P.Square()
|
|
1782
|
+
>>> self.transpose = P.Transpose()
|
|
1783
|
+
>>> self.k_means_centroids = P.KMeansCentroids(True)
|
|
1784
|
+
|
|
1785
|
+
>>> def construct(self, x, y):
|
|
1786
|
+
>>> p1 = self.reduce_sum(self.square(x), -1)
|
|
1787
|
+
>>> p2 = self.transpose(self.reduce_sum(self.square(y), -1), (1, 0))
|
|
1788
|
+
>>> return self.k_means_centroids(x, y, p2, p1)
|
|
1789
|
+
|
|
1790
|
+
>>> def test_net():
|
|
1791
|
+
>>> data_type = np.float32
|
|
1792
|
+
>>> x = Tensor(np.random.uniform(-10, 10, (65536, 128)).astype(data_type))
|
|
1793
|
+
>>> y = P.Ones()((1048576, 128), mstype.float32)
|
|
1794
|
+
>>> net = Net()
|
|
1795
|
+
>>> local_sum, local_count, local_avg_distance = net(x, y)
|
|
1796
|
+
"""
|
|
1797
|
+
|
|
1798
|
+
@prim_attr_register
|
|
1799
|
+
def __init__(self, use_actual_distance):
|
|
1800
|
+
validator.check_value_type('use_actual_distance', use_actual_distance, [bool], self.name)
|
|
1801
|
+
self.init_prim_io_names(inputs=['x', 'y', 'sum_square_y', 'sum_square_x'],
|
|
1802
|
+
outputs=['segment_sum', 'segment_count', 'kmean_total_sum'])
|
|
1803
|
+
|
|
1804
|
+
def infer_shape(self, x_shape, y_shape, sum_square_y_shape, sum_square_x_shape):
|
|
1805
|
+
"""infer shape of primitive"""
|
|
1806
|
+
expected_shape_size = 2
|
|
1807
|
+
validator.check_int(len(x_shape), expected_shape_size, validator.EQ, "dims of x", self.name)
|
|
1808
|
+
validator.check_int(len(y_shape), expected_shape_size, validator.EQ, "dims of y", self.name)
|
|
1809
|
+
validator.check_int(len(sum_square_y_shape), expected_shape_size, validator.EQ,
|
|
1810
|
+
"dims of sum_square_y", self.name)
|
|
1811
|
+
validator.check_int(len(sum_square_x_shape), expected_shape_size, validator.EQ,
|
|
1812
|
+
"dims of sum_square_x", self.name)
|
|
1813
|
+
|
|
1814
|
+
validator.check_int(x_shape[1], y_shape[1], validator.EQ,
|
|
1815
|
+
"the second dim of x and the second dim of y", self.name)
|
|
1816
|
+
validator.check_int(y_shape[0], sum_square_y_shape[1], validator.EQ,
|
|
1817
|
+
"the first dim of y and the second dim of sum_square_y", self.name)
|
|
1818
|
+
validator.check_int(x_shape[0], sum_square_x_shape[0], validator.EQ,
|
|
1819
|
+
"the first dim of x and the first dim of sum_square_x", self.name)
|
|
1820
|
+
validator.check_int(sum_square_y_shape[0], sum_square_x_shape[1], validator.EQ,
|
|
1821
|
+
"the first dim of sum_square_y and the first dim of sum_square_x",
|
|
1822
|
+
self.name)
|
|
1823
|
+
validator.check_int(sum_square_y_shape[0], 1, validator.EQ,
|
|
1824
|
+
"the first dim of sum_square_y", self.name)
|
|
1825
|
+
|
|
1826
|
+
k = y_shape[0]
|
|
1827
|
+
em_size = x_shape[1]
|
|
1828
|
+
return (k, em_size), (k, 1), (1)
|
|
1829
|
+
|
|
1830
|
+
|
|
1831
|
+
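# Hedged sketch of why KMeansCentroids takes sum_square_x and sum_square_y as
# inputs: pairwise squared distances can be assembled from precomputed squared
# norms via ||x - y||^2 = ||x||^2 - 2 * x @ y.T + ||y||^2, so only the cross
# term needs a matrix multiply. NumPy illustration only; the helper name
# `pairwise_sq_dist_ref` is not part of MindSpore.
import numpy as np

def pairwise_sq_dist_ref(x, y):
    sum_square_x = np.sum(np.square(x), axis=-1, keepdims=True)    # (n, 1)
    sum_square_y = np.sum(np.square(y), axis=-1, keepdims=True).T  # (1, k)
    return sum_square_x - 2.0 * x @ y.T + sum_square_y             # (n, k)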
class ClipByNorm(PrimitiveWithInfer):
|
|
1832
|
+
r"""
|
|
1833
|
+
Clips tensor values to a maximum :math:`L_2`-norm.
|
|
1834
|
+
|
|
1835
|
+
Note:
|
|
1836
|
+
The output tensor of this operator remains the same with input tensor if the :math:`L_2`-norm of the input
|
|
1837
|
+
tensor is not greater than the argument `clip_norm`. Otherwise the output tensor will be normalized as:
|
|
1838
|
+
|
|
1839
|
+
.. math::
|
|
1840
|
+
\text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
|
|
1841
|
+
|
|
1842
|
+
where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
|
|
1843
|
+
|
|
1844
|
+
Args:
|
|
1845
|
+
axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specific dimension.
|
|
1846
|
+
Default: ``None``, all dimensions to calculate.
|
|
1847
|
+
|
|
1848
|
+
Inputs:
|
|
1849
|
+
- **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
|
|
1850
|
+
- **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
|
|
1851
|
+
Or a Tensor which shape can be broadcast to the shape of `x`. The type must be float16 or float32.
|
|
1852
|
+
|
|
1853
|
+
Outputs:
|
|
1854
|
+
Tensor, clipped Tensor with the same shape as the `x`, whose type is float32.
|
|
1855
|
+
|
|
1856
|
+
Raises:
|
|
1857
|
+
TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
|
|
1858
|
+
TypeError: If dtype of `x` is neither float16 nor float32.
|
|
1859
|
+
TypeError: If dtype of `clip_norm` is neither float16 nor float32.
|
|
1860
|
+
|
|
1861
|
+
Supported Platforms:
|
|
1862
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1863
|
+
|
|
1864
|
+
Examples:
|
|
1865
|
+
>>> import numpy as np
|
|
1866
|
+
>>> import mindspore
|
|
1867
|
+
>>> from mindspore import Tensor
|
|
1868
|
+
>>> from mindspore.ops.operations import _inner_ops as inner
|
|
1869
|
+
>>> clip_by_norm = inner.ClipByNorm()
|
|
1870
|
+
>>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
|
|
1871
|
+
>>> clip_norm = Tensor(np.array([100]).astype(np.float32))
|
|
1872
|
+
>>> output = clip_by_norm(x, clip_norm)
|
|
1873
|
+
>>> print(output.shape)
|
|
1874
|
+
(4, 16)
|
|
1875
|
+
"""
|
|
1876
|
+
|
|
1877
|
+
@prim_attr_register
|
|
1878
|
+
def __init__(self, axis=None):
|
|
1879
|
+
"""Initialize ClipByNorm"""
|
|
1880
|
+
self.axis = () if axis is None else axis
|
|
1881
|
+
validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
|
|
1882
|
+
axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
|
|
1883
|
+
for i, value in enumerate(axis_check):
|
|
1884
|
+
validator.check_value_type('axis[%d]' % i, value, [int], self.name)
|
|
1885
|
+
self.init_attrs['axis'] = self.axis
|
|
1886
|
+
self.add_prim_attr('axis', self.axis)
|
|
1887
|
+
self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])
|
|
1888
|
+
|
|
1889
|
+
def infer_shape(self, x_shape, clip_norm_shape):
|
|
1890
|
+
"""Infer shape for ClipByNorm"""
|
|
1891
|
+
x_dim = len(x_shape)
|
|
1892
|
+
axis = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
|
|
1893
|
+
for _, value in enumerate(axis):
|
|
1894
|
+
validator.check_int_range(value, -x_dim, x_dim, validator.INC_LEFT, 'axis', self.name)
|
|
1895
|
+
return x_shape
|
|
1896
|
+
|
|
1897
|
+
def infer_dtype(self, x_type, clip_norm_type):
|
|
1898
|
+
"""Infer data type for ClipByNorm"""
|
|
1899
|
+
validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name)
|
|
1900
|
+
validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
|
|
1901
|
+
[mstype.float16, mstype.float32], self.name)
|
|
1902
|
+
return mstype.float32
|
|
1903
|
+
|
|
1904
|
+
|
|
1905
|
+
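# Hedged NumPy sketch of the ClipByNorm formula documented above: the input is
# rescaled to clip_norm only when its L2 norm exceeds clip_norm, and the result
# is returned as float32. Illustration only; `clip_by_norm_ref` is not the
# MindSpore kernel.
import numpy as np

def clip_by_norm_ref(x, clip_norm, axis=None):
    l2 = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    scale = np.where(l2 > clip_norm, clip_norm / np.maximum(l2, 1e-12), 1.0)
    return (x * scale).astype(np.float32)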
class TopTypeof(Primitive):
|
|
1906
|
+
"""
|
|
1907
|
+
Internal primitive method, to speed up mindspore.ops.typeof.
|
|
1908
|
+
|
|
1909
|
+
Returns the top type of the input data.
|
|
1910
|
+
|
|
1911
|
+
In Pynative mode, returns the top type in cache.
|
|
1912
|
+
|
|
1913
|
+
Supported Platforms:
|
|
1914
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1915
|
+
"""
|
|
1916
|
+
|
|
1917
|
+
@prim_attr_register
|
|
1918
|
+
def __init__(self):
|
|
1919
|
+
self.prim = Primitive('TopTypeof')
|
|
1920
|
+
self.typeof_cache = {
|
|
1921
|
+
'slice': mstype.Slice(),
|
|
1922
|
+
'list': mstype.List(),
|
|
1923
|
+
'tuple': mstype.Tuple(),
|
|
1924
|
+
'Tensor': mstype.tensor_type,
|
|
1925
|
+
'NoneType': mstype.NoneType(),
|
|
1926
|
+
'int': mstype.Int(),
|
|
1927
|
+
'bool': mstype.Bool(),
|
|
1928
|
+
'ellipsis': mstype.Ellipsis_(),
|
|
1929
|
+
'dict': mstype.Dict()
|
|
1930
|
+
}
|
|
1931
|
+
|
|
1932
|
+
def __call__(self, x):
|
|
1933
|
+
index_type = type(x).__name__
|
|
1934
|
+
if 'Tensor' in index_type:
|
|
1935
|
+
index_type = 'Tensor'
|
|
1936
|
+
if index_type in self.typeof_cache:
|
|
1937
|
+
return self.typeof_cache.get(index_type)
|
|
1938
|
+
return _pynative_executor.constant_folding(self.prim, x)
|
|
1939
|
+
|
|
1940
|
+
|
|
1941
|
+
class MixedPrecisionCast(Primitive):
|
|
1942
|
+
r"""
|
|
1943
|
+
Internal primitive method, to achieve mindspore.functional.mixed_precision_cast.
|
|
1944
|
+
|
|
1945
|
+
Note:
|
|
1946
|
+
This internal primitive method used to do mixed precision conversion.
|
|
1947
|
+
Only the input object with float dtype will be cast.
|
|
1948
|
+
|
|
1949
|
+
Inputs:
|
|
1950
|
+
- **dtype** (Union[Float16, Float32]) - The data type of the output object.
|
|
1951
|
+
- **input** (Union[Tensor, Tuple, Dictionary, KeywordArg]) - The object to be cast.
|
|
1952
|
+
|
|
1953
|
+
Outputs:
|
|
1954
|
+
Object, its dtype is the same as `dtype` and shape is the same as 'input'.
|
|
1955
|
+
|
|
1956
|
+
Supported Platforms:
|
|
1957
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1958
|
+
|
|
1959
|
+
Examples:
|
|
1960
|
+
>>> import numpy as np
|
|
1961
|
+
>>> from mindspore import Tensor
|
|
1962
|
+
>>> from mindspore import dtype as mstype
|
|
1963
|
+
>>> from mindspore.ops.operations import _inner_ops as inner
|
|
1964
|
+
>>> x = Tensor(np.ones([2, 3], dtype=np.float32))
|
|
1965
|
+
>>> out = inner.MixedPrecisionCast()(mstype.float16, x)
|
|
1966
|
+
>>> print(out.dtype)
|
|
1967
|
+
Float16
|
|
1968
|
+
"""
|
|
1969
|
+
|
|
1970
|
+
@prim_attr_register
|
|
1971
|
+
def __init__(self):
|
|
1972
|
+
"""Initialize MixedPrecisionCast"""
|
|
1973
|
+
self.init_prim_io_names(inputs=['dst_dtype', 'input_x'], outputs=['output'])
|
|
1974
|
+
self.cast = Cast()
|
|
1975
|
+
self.hyper_map = C.HyperMap()
|
|
1976
|
+
|
|
1977
|
+
def __call__(self, dst_dtype, x):
|
|
1978
|
+
def cast_inner(data):
|
|
1979
|
+
if isinstance(data, Tensor) and data.dtype in (mstype.float16, mstype.float32,
|
|
1980
|
+
mstype.float64, mstype.bfloat16):
|
|
1981
|
+
return self.cast(data, dst_dtype)
|
|
1982
|
+
return data
|
|
1983
|
+
|
|
1984
|
+
return self.hyper_map(cast_inner, x)
|
|
1985
|
+
|
|
1986
|
+
|
|
1987
|
+
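# Hedged sketch of the mixed-precision cast rule implemented above: only
# floating-point tensors inside a (possibly nested) structure are cast to the
# destination dtype; everything else passes through unchanged. Plain-Python
# illustration using NumPy arrays instead of mindspore Tensors; the helper name
# is illustrative only.
import numpy as np

def mixed_precision_cast_ref(dst_dtype, value):
    if isinstance(value, (tuple, list)):
        return type(value)(mixed_precision_cast_ref(dst_dtype, v) for v in value)
    if isinstance(value, dict):
        return {k: mixed_precision_cast_ref(dst_dtype, v) for k, v in value.items()}
    if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating):
        return value.astype(dst_dtype)
    return value

# mixed_precision_cast_ref(np.float16, (np.ones(2, np.float32), np.arange(3)))
# -> (float16 array, untouched int array)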
class CheckBprop(PrimitiveWithInfer):
|
|
1988
|
+
"""
|
|
1989
|
+
Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.
|
|
1990
|
+
|
|
1991
|
+
Args:
|
|
1992
|
+
prim_to_check (str): The name of the primitive being checked. Default: ''.
|
|
1993
|
+
|
|
1994
|
+
Inputs:
|
|
1995
|
+
- **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
|
|
1996
|
+
- **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.
|
|
1997
|
+
|
|
1998
|
+
Outputs:
|
|
1999
|
+
Tuple[Tensor], the `input_x`,
|
|
2000
|
+
if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
|
|
2001
|
+
|
|
2002
|
+
Raises:
|
|
2003
|
+
TypeError: If `input_x` or `input_y` is not a Tensor.
|
|
2004
|
+
|
|
2005
|
+
Supported Platforms:
|
|
2006
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2007
|
+
|
|
2008
|
+
Examples:
|
|
2009
|
+
>>> class Net(nn.Cell):
|
|
2010
|
+
... def __init__(self):
|
|
2011
|
+
... super(Net, self).__init__()
|
|
2012
|
+
... self.op = ops.CheckBprop()
|
|
2013
|
+
... def construct(self, x, y):
|
|
2014
|
+
... return self.op(x, y)
|
|
2015
|
+
...
|
|
2016
|
+
>>> net = Net()
|
|
2017
|
+
>>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
|
|
2018
|
+
>>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
|
|
2019
|
+
>>> output = net(input_x, input_y)
|
|
2020
|
+
>>> print(output)
|
|
2021
|
+
(Tensor(shape=[2, 2], dtype=Float32, value=
|
|
2022
|
+
[[ 2.00000000e+00, 2.00000000e+00],
|
|
2023
|
+
[ 2.00000000e+00, 2.00000000e+00]]),)
|
|
2024
|
+
"""
|
|
2025
|
+
|
|
2026
|
+
@prim_attr_register
|
|
2027
|
+
def __init__(self, prim_to_check=""):
|
|
2028
|
+
"""Initialize CheckBprop"""
|
|
2029
|
+
self.prim_to_check = prim_to_check
|
|
2030
|
+
|
|
2031
|
+
def infer_shape(self, xshapes, yshapes):
|
|
2032
|
+
"""infer shape"""
|
|
2033
|
+
tips = f"user defined method 'bprop'"
|
|
2034
|
+
validator.check_value_type('grads', xshapes, (tuple,), tips)
|
|
2035
|
+
validator.check_value_type('params', yshapes, (tuple,), tips)
|
|
2036
|
+
if not len(xshapes) == len(yshapes):
|
|
2037
|
+
raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
|
|
2038
|
+
f"the number of input arguments except 'out' and 'dout', "
|
|
2039
|
+
f"which is:{len(yshapes)} but got {len(xshapes)}.")
|
|
2040
|
+
|
|
2041
|
+
def shape_equal(shape1, shape2):
|
|
2042
|
+
if len(shape1) != len(shape2):
|
|
2043
|
+
return False
|
|
2044
|
+
for shape_axis1, shape_axis2 in zip(shape1, shape2):
|
|
2045
|
+
if shape_axis1 == -1 or shape_axis2 == -1:
|
|
2046
|
+
continue
|
|
2047
|
+
if shape_axis1 != shape_axis2:
|
|
2048
|
+
return False
|
|
2049
|
+
return True
|
|
2050
|
+
|
|
2051
|
+
for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
|
|
2052
|
+
if not xshape or not yshape:
|
|
2053
|
+
continue
|
|
2054
|
+
|
|
2055
|
+
if not shape_equal(xshape, yshape):
|
|
2056
|
+
raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
|
|
2057
|
+
f"should have the same shape as the {i}th argument, "
|
|
2058
|
+
f"which is:{yshape}, but got: {xshape}.")
|
|
2059
|
+
return xshapes
|
|
2060
|
+
|
|
2061
|
+
def infer_dtype(self, xdtypes, ydtypes):
|
|
2062
|
+
"""infer dtype"""
|
|
2063
|
+
tips = f"user defined method 'bprop'"
|
|
2064
|
+
validator.check_value_type('grads', xdtypes, (tuple,), tips)
|
|
2065
|
+
validator.check_value_type('params', ydtypes, (tuple,), tips)
|
|
2066
|
+
if not len(xdtypes) == len(ydtypes):
|
|
2067
|
+
raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to "
|
|
2068
|
+
f"the number of input arguments except 'out' and 'dout', "
|
|
2069
|
+
f"which is:{len(ydtypes)} but got {len(xdtypes)}.")
|
|
2070
|
+
checking_range = len(ydtypes)
|
|
2071
|
+
for i in range(checking_range):
|
|
2072
|
+
xdtype = xdtypes[i]
|
|
2073
|
+
ydtype = ydtypes[i]
|
|
2074
|
+
if isinstance(xdtype, mstype.AnythingType) or isinstance(ydtype, mstype.AnythingType):
|
|
2075
|
+
continue
|
|
2076
|
+
if isinstance(ydtype, mstype.FunctionType):
|
|
2077
|
+
if not isinstance(xdtype, mstype.EnvType):
|
|
2078
|
+
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
|
|
2079
|
+
f"should be {mstype.EnvType}, but got {xdtype}.")
|
|
2080
|
+
if xdtype != ydtype:
|
|
2081
|
+
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
|
|
2082
|
+
f"should have the same dtype as the {i}th argument, "
|
|
2083
|
+
f"which is:{ydtype}, but got: {xdtype}.")
|
|
2084
|
+
return xdtypes
|
|
2085
|
+
|
|
2086
|
+
|
|
2087
|
+
check_bprop = CheckBprop()
|
|
2088
|
+
|
|
2089
|
+
|
|
2090
|
+
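# Hedged sketch of the consistency rule CheckBprop enforces above: each
# gradient returned by a user-defined bprop must match the shape and dtype of
# the corresponding forward input. The real primitive also tolerates dynamic
# (-1) dimensions; this strict-equality sketch does not, and the helper name
# `check_bprop_ref` is illustrative only.
import numpy as np

def check_bprop_ref(grads, params):
    if len(grads) != len(params):
        raise ValueError("bprop must return one gradient per input argument")
    for i, (g, p) in enumerate(zip(grads, params)):
        if g.dtype != p.dtype or g.shape != p.shape:
            raise ValueError(f"gradient {i} does not match input {i}")
    return grads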
class SameTypeShape(PrimitiveWithInfer):
    """
    Checks whether the data type and shape of two tensors are the same.

    Refer to :func:`mindspore.ops.same_type_shape` for more detail.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> output = ops.SameTypeShape()(input_x, input_y)
        >>> print(output)
        [[2. 2.]
         [2. 2.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Same"""

    def __call__(self, x, y):
        """run in PyNative mode"""
        validator.check_value_type('x', x, Tensor, self.name)
        validator.check_value_type('y', y, Tensor, self.name)
        validator.check('x dtype', x.dtype, 'y dtype', y.dtype, validator.EQ, self.name, TypeError)
        validator.check('x shape', x.shape, 'y shape', y.shape, validator.EQ, self.name)
        return x

    def __infer__(self, x, y):
        validator.check_subclass('x', x['dtype'], mstype.tensor_type, self.name)
        validator.check_subclass('y', y['dtype'], mstype.tensor_type, self.name)
        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], validator.EQ, self.name, TypeError)
        validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name)
        return x


same_type_shape_ = SameTypeShape()


def _is_subclass_(type_, dtype):
    if not isinstance(type_, typing.Type):
        return False
    return typing.is_subclass(type_, dtype)


class IsSubClass(PrimitiveWithInfer):
    """
    Checks whether this type is a sub-class of another type.

    Inputs:
        - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `sub_type` or `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc)
        >>> print(output)
        True
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __infer__(self, sub_type, type_):
        sub_type_t = sub_type['value']
        type_v = type_['value']

        validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
        validator.check_value_type("type_", type_v, [mstype.Type], self.name)

        value = _is_subclass_(sub_type_t, type_v)

        out = {'shape': (),
               'dtype': mstype.type_type,
               'value': value}
        return out


issubclass_ = IsSubClass()


class IsInstance(PrimitiveWithInfer):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        - **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inst = 1
        >>> output = ops.IsInstance()(inst, mindspore.int32)
        >>> print(output)
        False
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __infer__(self, inst, type_):
        sub_type_t = inst['dtype']
        type_v = type_['value']

        validator.check_value_type("type_", type_v, [mstype.Type], self.name)

        if type_v == mstype.list_:
            value = isinstance(sub_type_t, list)
        elif type_v == mstype.tuple_:
            value = isinstance(sub_type_t, tuple)
        else:
            value = _is_subclass_(sub_type_t, type_v)

        out = {'shape': (),
               'dtype': mstype.type_type,
               'value': value}
        return out

class ConvertToAdapterTensor(Primitive):
    """
    Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type,
    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        A tensor, whose type is MSAdapter's Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([1, 2, 3])
        >>> x = ops.ConvertToAdapterTensor()(x)
        >>> print(x)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize"""

    def __call__(self, x):
        """Run in PyNative mode"""
        return ms_adapter_registry.tensor(x, cast_tensor=True)


convert_to_adapter_tensor = ConvertToAdapterTensor()


class ConvertToMsTensor(Primitive):
    """
    Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type,
    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        A tensor, whose type is MindSpore's Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([1, 2, 3])
        >>> x = ops.ConvertToMsTensor()(x)
        >>> print(x)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize"""

    def __call__(self, x):
        """Run in PyNative mode"""
        if isinstance(x, StubTensor):
            return StubTensor(stub=x.stub, tensor=x.tensor)
        return ops.auto_generate.deepcopy(x)


convert_to_ms_tensor = ConvertToMsTensor()

class GetGrad(Primitive):
    """
    Use the position id or Parameter object to get the gradient from the output
    returned by :func:`mindspore.ops.grad`.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GetGrad"""
        self.init_prim_io_names(
            inputs=['gradients', 'x'], outputs=['gradient'])

    def __call__(self, gradients, x):
        if not isinstance(x, int) and not isinstance(x, Parameter):
            raise TypeError(
                f"For `get_grad`, the `x` should be an integer or a Parameter, but got {x}")
        hash_id = x
        if isinstance(x, Parameter):
            hash_id = x.name
        output = None

        def _get_grad(grads, identifier):
            if isinstance(grads, tuple):
                if len(grads) != 2 or identifier != grads[0]:
                    for gradient in grads:
                        _get_grad(gradient, identifier)
                else:
                    nonlocal output
                    output = grads[1]
                    return

        _get_grad(gradients, hash_id)
        if output is None:
            raise RuntimeError(
                f"Can not find the gradient for position or Parameter {x}")
        return output

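# Added usage sketch (illustration only): GetGrad expects the nested tuple layout in which
# each leaf gradient is paired with its identifier, i.e. ``(position_id_or_param_name, grad)``,
# as produced when gradients are requested together with their ids. The values below are
# made up; the lookup itself is the plain recursive search implemented in __call__.
def _get_grad_usage_sketch():
    get_grad_op = GetGrad()
    grads = ((0, Tensor([1.0])), ("net.weight", Tensor([2.0])))
    # returns Tensor([1.0]); a Parameter named "net.weight" would match the second leaf by name
    return get_grad_op(grads, 0)
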
class IsParameter(PrimitiveWithInfer):
    """
    Check if input is `Parameter`
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsParameter"""

    def __call__(self, x):
        return isinstance(x, Parameter)

    def __infer__(self, x):
        return {'shape': [],
                'dtype': mstype.bool_,
                'value': isinstance(x['dtype'], mstype.RefType)}

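# Added usage sketch (illustration only): in PyNative mode IsParameter reduces to an
# isinstance check against Parameter, so the assumed calls below return (True, False).
def _is_parameter_sketch():
    from mindspore import Parameter
    check = IsParameter()
    w = Parameter(Tensor([1.0]), name="w")
    return check(w), check(Tensor([1.0]))
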
class TileSize(Primitive):
    r"""
    Tile size for matmul
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TileSize"""
        self.init_prim_io_names(inputs=['shape', 'out_shape', 'ndim'], outputs=['output'])

    def __call__(self, shape, out_shape, ndim):
        size = [1] * ndim
        for idx, (i, j) in enumerate(zip(shape, out_shape)):
            if i != j:
                size[idx] = j
        return tuple(size)

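# Added sketch (illustration only): TileSize.__call__ computes, dimension by dimension,
# the repeat factor needed to tile `shape` up to `out_shape`, which is typically fed to a
# Tile op when broadcasting matmul operands. The shapes below are made up.
def _tile_size_sketch():
    tile_size = TileSize()
    # shape (2, 1, 3) must be repeated 4 times along axis 1 to reach (2, 4, 3) -> (1, 4, 1)
    return tile_size((2, 1, 3), (2, 4, 3), 3)
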
class GetitemTensorIndexInfo(Primitive):
    r"""
    Get getitem tensor index info
    """

    @prim_attr_register
    def __init__(self, is_ascend):
        """Initialize GetitemTensorIndexInfo"""
        self.init_prim_io_names(inputs=['data', 'index'],
                                outputs=["new_index", "tensor_update_types", "tensor_update_args"])
        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
        self.is_ascend = is_ascend

    def __call__(self, data, index):
        return Tensor_.getitem_index_info(data, index, self.is_ascend)


class SetitemTensorIndexInfo(Primitive):
    r"""
    Get setitem tensor index info
    """

    @prim_attr_register
    def __init__(self, is_ascend):
        """Initialize SetitemTensorIndexInfo"""
        self.init_prim_io_names(
            inputs=['data', 'index', 'value'], outputs=['new_index',
                                                        'v_transfer_types',
                                                        'v_transfer_args',
                                                        'tensor_update_types',
                                                        'tensor_update_args'])
        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
        self.is_ascend = is_ascend

    def __call__(self, data, index, value):
        return Tensor_.setitem_index_info(data, index, value, self.is_ascend)


class IsConstant(Primitive):
    r"""
    Check if the input is constant
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IsConstant"""

    def __call__(self, x):
        return True


class SelectView(Primitive):
    r"""
    Select tensor of view
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output'])


class CopyWithSlice(Primitive):
    r"""
    Copy data to a non-contiguous tensor
    """

    @prim_attr_register
    def __init__(self):
        self.add_prim_attr('side_effect_mem', True)
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])

class FFN(Primitive):
    r"""
    The FFN computation is similar to a feed-forward network: it contains matmul + gelu + matmul.

    Args:
        activation (string): The activation type, set to 'fastgelu' or 'gelu'.
            Only 'fastgelu' is supported for now. Default: "fastgelu".
        inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
            Only 1 is supported for now. Default: 0.

    Inputs:
        - **x** (Tensor) - The input tensor with data type of int8, float16.
          Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
        - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
          Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
          indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
          the 2nd expert does nothing, and so on.
        - **bias1** (Tensor) - The bias1 tensor with data type of float16.
          Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
        - **bias2** (Tensor) - The bias2 tensor with data type of float16.
          Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
        - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled yet.
        - **offset** (Tensor) - The offset tensor with data type of float16. Not enabled yet.
        - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enabled yet.
        - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enabled yet.

    Outputs:
        Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`, with data type of float16.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops
        >>> b = 4
        >>> s = 128
        >>> h = 1024
        >>> h_f = 4 * h
        >>> e = 16
        >>> x = Tensor(np.random.randn(s, h).astype(np.float16))
        >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
        >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
        >>> expert_tokens = Tensor(np.full(e, 8))
        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
        >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
        >>> ffn = _inner_ops.FFN("fastgelu", 1)
        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self, activation, inner_precise):
        """Initialize FFN."""
        self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2",
                                        "antiquant_scale1", "antiquant_scale2",
                                        "antiquant_offset1", "antiquant_offset2"],
                                outputs=["y"])
        cls_name = self.name
        validator.check_value_type("activation", activation, [str], cls_name)
        validator.check_value_type("inner_precise", inner_precise, [int], cls_name)

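# Added reference sketch (illustration only, assuming the documented semantics): per expert,
# FFN is matmul -> activation -> matmul with bias adds, and tokens are routed to experts in
# contiguous chunks whose lengths come from `expert_tokens`. A plain NumPy rendering of that
# contract, using one common FastGelu approximation (the exact kernel formula may differ):
def _ffn_reference_sketch(x, w1, w2, expert_tokens, bias1, bias2):
    import numpy as np

    def fast_gelu(v):
        return v * np.exp(0.851 * (v - np.abs(v))) / (1 + np.exp(-1.702 * np.abs(v)))

    outputs, start = [], 0
    for e, num_tokens in enumerate(expert_tokens):
        tokens = x[start:start + num_tokens]                    # chunk handled by expert e
        hidden = fast_gelu(tokens @ w1[e] + bias1[e])           # matmul + activation
        outputs.append(hidden @ w2[e] + bias2[e])               # second matmul
        start += num_tokens
    return np.concatenate(outputs, axis=0)
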
class _VirtualConverterEnd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator.
    """

    @prim_attr_register
    def __init__(self, input_nums):
        """Initialize _VirtualConverterEnd."""
        self.input_nums = input_nums

    def infer_shape(self, *args):
        return (args[0][0] * self.input_nums,) + tuple(args[0][1:])

    def infer_dtype(self, *args):
        return args[0]


class _VirtualConverterBegin(PrimitiveWithInfer):
    """
    Auto parallel virtual operator.
    """

    @prim_attr_register
    def __init__(self, output_nums):
        """Initialize _VirtualConverterBegin."""
        self.output_nums = output_nums

    def infer_shape(self, arg):
        if self.output_nums == 0:
            raise ValueError("output_nums can't be zero.")
        new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:])
        return (new_arg,) * self.output_nums

    def infer_dtype(self, arg):
        return (arg,) * self.output_nums