mindspore-2.4.0-cp311-cp311-macosx_10_15_x86_64.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
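
The mindspore/amp.py module above carries the mixed-precision helpers. A minimal sketch of how it is typically used, assuming the public mindspore.amp API of the 2.x releases; the network and level here are illustrative:

```python
from mindspore import amp, nn

# Any nn.Cell works the same way; a single Dense layer keeps the sketch small.
net = nn.Dense(4, 2)

# Rewrite the network so eligible layers run in float16 ("O2" level).
net = amp.auto_mixed_precision(net, amp_level="O2")

# A dynamic loss scaler for numerically stable float16 training.
loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)
```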
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
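
The mindspore/common files above define the core Tensor, Parameter, dtype, and seed machinery. A minimal sketch, assuming the standard top-level re-exports:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter

# Tensors wrap NumPy data with an explicit MindSpore dtype.
x = Tensor(np.ones((2, 3)), dtype=ms.float32)

# Parameters are named, trainable tensors used inside nn.Cell.
w = Parameter(Tensor(np.zeros((3, 2)), dtype=ms.float32), name="w")

print(x.shape, x.dtype, w.name)  # (2, 3) Float32 w
```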
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
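
mindspore/communication above holds the collective-communication management layer. A minimal sketch, assuming a job started by one of MindSpore's distributed launchers (it will not run standalone):

```python
from mindspore.communication import init, get_rank, get_group_size

# init() reads the launcher's environment (e.g. msrun or OpenMPI) and joins
# the default communication group; it must run before any collective op.
init()
print(f"rank {get_rank()} of {get_group_size()}")
```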
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
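
mindspore/context.py above configures execution mode and device targeting. A minimal sketch, assuming the standard set_context/get_context API:

```python
import mindspore as ms

# Choose eager (PyNative) vs. graph execution, and the backend to run on.
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")
print(ms.get_context("mode"))  # 1 == PYNATIVE_MODE, 0 == GRAPH_MODE
```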
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
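
The mindspore/dataset tree above is the data-loading engine: sources, samplers, caching, and the audio/text/vision transforms. A minimal pipeline sketch, assuming the standard dataset API; the generator source is illustrative:

```python
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

def generator():
    # Hypothetical in-memory source: 10 random "images" with integer labels.
    for i in range(10):
        yield np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8), i

dataset = ds.GeneratorDataset(generator, column_names=["image", "label"])
dataset = dataset.map(operations=[vision.Resize((224, 224)), vision.HWC2CHW()],
                      input_columns="image")
dataset = dataset.batch(4)

for batch in dataset.create_dict_iterator():
    print(batch["image"].shape)  # (4, 3, 224, 224)
    break
```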
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
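
mindspore/experimental/optim above provides PyTorch-style optimizers and LR schedulers. A minimal sketch, assuming the documented experimental API:

```python
from mindspore import nn
from mindspore.experimental import optim

net = nn.Dense(4, 2)
optimizer = optim.AdamW(net.trainable_params(), lr=1e-3, weight_decay=0.01)

# Schedulers step explicitly, mirroring the torch.optim.lr_scheduler style.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
for epoch in range(3):
    # ... run the training steps for this epoch, then advance the schedule ...
    scheduler.step()
```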
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
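
mindspore/hal above is the hardware-abstraction layer (devices, streams, events, memory). A minimal sketch, assuming the documented mindspore.hal API:

```python
import mindspore as ms

# Query the devices visible to a given backend.
print(ms.hal.device_count("CPU"))

# Streams allow asynchronous kernel execution on accelerators.
stream = ms.hal.Stream()
with ms.hal.StreamCtx(stream):
    pass  # kernels launched here would run on `stream`
stream.synchronize()
```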
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
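
mindspore/mindrecord above implements the MindRecord file format plus converters for CIFAR, MNIST, ImageNet, CSV, and TFRecord. A minimal write/read sketch, assuming the documented FileWriter/FileReader API; the file name and schema are illustrative:

```python
from mindspore.mindrecord import FileWriter, FileReader

# Write two records with a simple schema.
writer = FileWriter(file_name="demo.mindrecord", shard_num=1, overwrite=True)
writer.add_schema({"data": {"type": "bytes"}, "label": {"type": "int32"}}, "demo_schema")
writer.write_raw_data([{"data": b"\x01\x02", "label": 0},
                       {"data": b"\x03\x04", "label": 1}])
writer.commit()

# Read them back.
reader = FileReader(file_name="demo.mindrecord")
for record in reader.get_next():
    print(record["label"])
reader.close()
```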
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
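
mindspore/mint above is the PyTorch-aligned functional namespace introduced in recent releases. A minimal sketch, assuming the documented mint API:

```python
import numpy as np
from mindspore import mint, Tensor

x = Tensor(np.random.randn(2, 3).astype(np.float32))

# mint mirrors common torch operators and nn.functional entry points.
y = mint.add(x, x)
z = mint.nn.functional.relu(y)
print(mint.mean(z))
```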
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
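
The mindspore/nn tree above holds layers, losses, optimizers, and training wrappers. A minimal training-step sketch, assuming the standard 2.x functional-differentiation API; the toy data are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = Net()
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3)

def forward(x, y):
    return loss_fn(net(x), y)

# value_and_grad differentiates w.r.t. the optimizer's parameters.
grad_fn = ms.value_and_grad(forward, None, optimizer.parameters)

x = Tensor(np.random.randn(8, 4).astype(np.float32))
y = Tensor(np.random.randint(0, 2, (8,)).astype(np.int32))
loss, grads = grad_fn(x, y)
optimizer(grads)  # apply the update
print(loss)
```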
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
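
mindspore/numpy above is the NumPy-compatible interface over MindSpore tensors. A minimal sketch:

```python
import mindspore as ms
import mindspore.numpy as mnp

a = mnp.arange(6).reshape(2, 3).astype(ms.float32)  # a MindSpore Tensor, not an ndarray
b = mnp.ones((3, 2))                                # float32 by default
c = mnp.matmul(a, b)                                # shape (2, 2)
print(c.sum())
```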
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1629 @@
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""basic"""
+from __future__ import absolute_import
+
+import math
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import context, log as logger
+from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
+from mindspore.common.seed import _get_graph_seed
+from mindspore.common.tensor import Tensor
+from mindspore.common.initializer import initializer, HeUniform, Uniform
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops.function.nn_func import interpolate_ext
+from mindspore.ops.auto_generate import unfold_ext
+from mindspore.ops.operations import _inner_ops as inner
+from mindspore.ops.primitive import constexpr, Primitive, _primexpr
+from mindspore.common.parameter import Parameter
+from mindspore._extends import cell_attr_register
+from mindspore import _checkparam as Validator
+from mindspore.nn.cell import Cell
+from mindspore.nn.layer.activation import get_activation
+from mindspore.common._decorator import deprecated
+from mindspore.ops.auto_generate import dropout_ext_op, fold_ext
+from mindspore.common.generator import default_generator
+
+__all__ = ['Dropout', 'Flatten', 'Dense', 'Linear', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
+           'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Dropout1d',
+           'Dropout2d', 'Dropout3d', 'Upsample', 'Roll', 'Identity', 'Unflatten', 'DropoutExt']
+
+
+class L1Regularizer(Cell):
+    r"""
+    Applies L1 regularization to the weights.
+
+    L1 regularization encourages sparsity in the weights.
+
+    .. math::
+        \text{loss} = \lambda * \text{reduce_sum}(\text{abs}(\omega))
+
+    where :math:`\lambda` is `scale` .
+
+    Note:
+        `scale` (the regularization factor) must be a number greater than 0.
+
+    Args:
+        scale (int, float): L1 regularization factor, which must be greater than 0.
+
+    Inputs:
+        - **weights** (Tensor) - The input of L1Regularizer with data type of float16 or float32.
+          The shape is :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
+
+    Outputs:
+        Tensor, whose dtype is the higher-precision data type between mindspore.float32 and the dtype
+        of `weights`, and whose shape is :math:`()`.
+
+    Raises:
+        TypeError: If `scale` is neither an int nor a float.
+        ValueError: If `scale` is not greater than 0.
+        ValueError: If `scale` is math.inf or math.nan.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> scale = 0.5
+        >>> net = ms.nn.L1Regularizer(scale)
+        >>> weights = ms.Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
+        >>> output = net(weights)
+        >>> print(output.asnumpy())
+        5.0
+    """
+
+    def __init__(self, scale):
+        """Initialize L1Regularizer."""
+        super(L1Regularizer, self).__init__()
+        Validator.check_value_type("scale", scale, [int, float], self.cls_name)
+        if scale <= 0:
+            raise ValueError(
+                f"For '{self.cls_name}', the 'scale' must be greater than 0, but got {scale}.")
+        if math.isinf(scale) or math.isnan(scale):
+            raise ValueError(
+                f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.")
+        self.abs = P.Abs()
+        self.reduce_sum = P.ReduceSum()
+        self.scale = Tensor(scale, dtype=mstype.float32)
+
+    def construct(self, weights):
+        const_utils.check_type_valid(
+            F.dtype(weights), mstype.number_type, 'weights')
+        l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
+        return l1_regularization
+
+
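The docstring example above can be checked by hand: with `scale` = 0.5 and weights [[1, -2], [-3, 4]], the loss is 0.5 * (1 + 2 + 3 + 4) = 5.0. A minimal NumPy sketch of the same reduction, for illustration only (not part of the packaged file):

    import numpy as np

    def l1_regularizer(weights, scale):
        # loss = scale * sum(|w|), reduced over every element
        return scale * np.abs(weights).sum()

    w = np.array([[1.0, -2.0], [-3.0, 4.0]], dtype=np.float32)
    print(l1_regularizer(w, 0.5))  # 5.0, matching the docstring example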
+class Dropout(Cell):
+    r"""
+    Dropout layer for the input.
+
+    Dropout is a means of regularization that reduces overfitting by preventing correlations between neuronal nodes.
+    During training, the operator randomly sets some neuron outputs to 0 according to `p`, the probability of
+    discarding an element, and multiplies the surviving outputs by :math:`\frac{1}{1-p}`.
+    During inference, this layer returns `x` unchanged.
+
+    This technique is proposed in paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
+    <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and has been shown to effectively reduce
+    overfitting and prevent the co-adaptation of neurons. See more details in `Improving neural networks by
+    preventing co-adaptation of feature detectors
+    <https://arxiv.org/pdf/1207.0580.pdf>`_.
+
+    Note:
+        - Each channel will be zeroed out independently on every construct call.
+        - Parameter `keep_prob` will be removed in a future version; please use parameter `p` instead.
+          Parameter `p` is the probability that an element of the input tensor is zeroed.
+        - Parameter `dtype` will be removed in a future version. It is not recommended to define this parameter.
+
+    Args:
+        keep_prob (float): Deprecated. The keep rate, greater than 0 and less than or equal to 1.
+            E.g. keep_prob=0.9 drops out 10% of input neurons. Default: ``0.5`` .
+        p (Union[float, int, None]): The dropout rate, greater than or equal to 0 and less than 1.
+            E.g. p=0.9 drops out 90% of input neurons. Default: ``None`` .
+        dtype (:class:`mindspore.dtype`): Data type of `input`. Default: ``mstype.float32`` .
+
+    Inputs:
+        - **x** (Tensor) - The input of Dropout with data type of float16 or float32.
+
+    Outputs:
+        Tensor, output tensor with the same shape as the `x`.
+
+    Raises:
+        TypeError: If `keep_prob` is not a float.
+        TypeError: If the dtype of `p` is neither float nor int.
+        TypeError: If dtype of `x` is neither float16 nor float32.
+        ValueError: If `keep_prob` is not in range (0, 1].
+        ValueError: If `p` is not in range [0, 1).
+        ValueError: If length of shape of `x` is less than 1.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
+        >>> net = nn.Dropout(p=0.2)
+        >>> net.set_train()
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 2, 3)
+    """
+
+    def __init__(self, keep_prob=0.5, p=None, dtype=mstype.float32):
+        """Initialize Dropout."""
+        super(Dropout, self).__init__()
+        if dtype != mstype.float32:
+            logger.warning(
+                "This parameter `dtype` will be deleted or invisible in the future. Please don't use it.")
+        if p is None:
+            logger.warning("For Dropout, this parameter `keep_prob` will be deprecated, please use `p` instead.")
+            Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
+            if keep_prob <= 0 or keep_prob > 1:
+                raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], "
+                                 f"but got {keep_prob}.")
+            seed0, seed1 = _get_graph_seed(0, "dropout")
+            self.dropout = P.Dropout(keep_prob, seed0, seed1)
+        else:
+            Validator.check_value_type('p', p, [float, int], self.cls_name)
+            if p < 0 or p >= 1:
+                raise ValueError(f"For '{self.cls_name}', the 'p' must be a number in range [0, 1), "
+                                 f"but got {p}.")
+            seed0, seed1 = _get_graph_seed(0, "dropout")
+            self.dropout = P.Dropout(1.0 - p, seed0, seed1)
+        self.p = p
+        self.keep_prob = keep_prob
+
+    def construct(self, x):
+        if not self.training or self.keep_prob == 1 or self.p == 0:
+            return x
+
+        out, _ = self.dropout(x)
+        return out
+
+    def extend_repr(self):
+        if self.p is None:
+            logger.warning("For Dropout, this parameter `keep_prob` will be deprecated, please use `p` instead.")
+            return f'keep_prob={self.keep_prob}'
+        return f'p={self.p}'
+
+
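As the Dropout docstring notes, training-time survivors are scaled by 1/(1-p) so that the expected activation matches inference. A minimal NumPy sketch of this inverted-dropout scheme (illustrative only; the operator's actual kernel and RNG differ):

    import numpy as np

    def inverted_dropout(x, p, training=True, seed=0):
        # Zero each element with probability p; scale survivors by 1/(1-p)
        # so that E[output] == x. At inference time the input passes through.
        if not training or p == 0:
            return x
        mask = np.random.default_rng(seed).random(x.shape) >= p
        return x * mask / (1.0 - p)

    x = np.ones((2, 2, 3), dtype=np.float32)
    out = inverted_dropout(x, p=0.2)
    print(out.shape)       # (2, 2, 3)
    print(np.unique(out))  # values are 0.0 or 1.25 == 1 / (1 - 0.2)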
+class DropoutExt(Cell):
+    r"""
+    Dropout layer for the input.
+
+    Dropout is a means of regularization that reduces overfitting by preventing correlations between neuronal nodes.
+    During training, the operator randomly sets some neuron outputs to 0 according to `p`, the probability of
+    discarding an element, and multiplies the surviving outputs by :math:`\frac{1}{1-p}`.
+    During inference, this layer returns `x` unchanged.
+
+    This technique is proposed in paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
+    <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and has been shown to effectively reduce
+    overfitting and prevent the co-adaptation of neurons. See more details in `Improving neural networks by
+    preventing co-adaptation of feature detectors
+    <https://arxiv.org/pdf/1207.0580.pdf>`_.
+
+    Note:
+        - Each channel will be zeroed out independently on every construct call.
+        - Parameter `p` is the probability that an element of the input tensor is zeroed.
+
+    Args:
+        p (float): The dropout rate of input neurons. E.g. `p` = 0.9 drops out 90% of input neurons.
+            Default: ``0.5`` .
+
+    Inputs:
+        - **x** (Tensor) - The input of Dropout.
+
+    Outputs:
+        Tensor, output tensor with the same shape as the `x`.
+
+    Raises:
+        TypeError: If the dtype of `p` is not float.
+        ValueError: If length of shape of `x` is less than 1.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
+        >>> net = nn.DropoutExt(p=0.2)
+        >>> net.set_train()
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 2, 3)
+    """
+
+    def __init__(self, p=0.5):
+        """Initialize DropoutExt."""
+        super(DropoutExt, self).__init__()
+        self.p = p
+        self.generator_step = Tensor(1, mstype.int64)
+
+    def construct(self, x):
+        if not self.training or self.p == 0:
+            return x
+
+        seed, offset = default_generator._step(self.generator_step)  # pylint: disable=protected-access
+        out, _ = dropout_ext_op(x, self.p, seed, offset)
+        return out
+
+
+class Dropout1d(Cell):
+    r"""
+    During training, randomly zeroes entire channels of the input tensor with probability `p`
+    from a Bernoulli distribution (For a 3-dimensional tensor with a shape of :math:`(N, C, L)`,
+    the channel feature map refers to a 1-dimensional feature map with the shape of :math:`L`).
+
+    For example, the :math:`j\_th` channel of the :math:`i\_th` sample in the batched input is a to-be-processed
+    `1D` tensor input[i,j].
+    Each channel will be zeroed out independently on every forward call with probability `p` using samples
+    from a Bernoulli distribution.
+
+    This technique is described in the paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
+    <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and has been shown to effectively reduce
+    overfitting and prevent the co-adaptation of neurons.
+    For more details, refer to `Improving neural networks by preventing co-adaptation of feature detectors
+    <https://arxiv.org/pdf/1207.0580.pdf>`_ .
+
+    `Dropout1d` can improve the independence between channel feature maps.
+
+    Args:
+        p (float, optional): The dropping probability of a channel, between 0 and 1, e.g. `p` = 0.8,
+            which means each channel has an 80% chance of being set to 0. Default: ``0.5`` .
+
+    Inputs:
+        - **x** (Tensor) - A tensor with shape :math:`(N, C, L)` or :math:`(C, L)`, where `N` is the batch size,
+          `C` is the number of channels, `L` is the feature length. The data type must be int8, int16, int32,
+          int64, float16, float32 or float64.
+
+    Outputs:
+        Tensor, has the same shape and data type as `x`.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+        TypeError: If the data type of `p` is not float.
+        ValueError: If `p` is out of the range `[0.0, 1.0]`.
+        ValueError: If the shape of `x` is not `2D` or `3D`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> op = ms.nn.Dropout1d(p=0.6)
+        >>> op.training = True
+        >>> a = ms.Tensor(np.ones((3, 3)), ms.float32)
+        >>> output = op(a)
+    """
+
+    def __init__(self, p=0.5):
+        """Initialize Dropout1d."""
+        super(Dropout1d, self).__init__()
+        Validator.check_value_type('p', p, [float], self.cls_name)
+        if p < 0 or p > 1:
+            raise ValueError(f"For '{self.cls_name}', the 'p' must be a number in range [0, 1], "
+                             f"but got {p}.")
+        self.prob = p
+
+    def construct(self, x):
+        if not self.training or self.prob == 0:
+            return x
+
+        out = F.dropout1d(x, self.prob)
+        return out
+
+
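Unlike element-wise Dropout above, Dropout1d draws one Bernoulli sample per channel and zeroes the whole channel slice. A NumPy sketch of that semantics (illustrative; it assumes the same 1/(1-p) training-time scaling):

    import numpy as np

    def dropout1d(x, p, seed=0):
        # One Bernoulli draw per (sample, channel); the mask broadcasts over
        # the length axis, so a whole channel is kept or zeroed together.
        keep = np.random.default_rng(seed).random(x.shape[:-1] + (1,)) >= p
        return x * keep / (1.0 - p)

    x = np.ones((3, 3, 4), dtype=np.float32)
    out = dropout1d(x, p=0.6)
    print(out[:, :, 0])  # each channel is all 0.0 or all 2.5 == 1 / (1 - 0.6)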
+class Dropout2d(Cell):
+    r"""
+    During training, randomly zeroes some channels of the input tensor with probability `p`
+    from a Bernoulli distribution (For a 4-dimensional tensor with a shape of :math:`NCHW`,
+    the channel feature map refers to a 2-dimensional feature map with the shape of :math:`HW`).
+
+    For example, the :math:`j\_th` channel of the :math:`i\_th` sample in the batched input is a to-be-processed
+    `2D` tensor input[i,j].
+    Each channel will be zeroed out independently on every forward call with probability `p` using samples
+    from a Bernoulli distribution.
+
+    `Dropout2d` can improve the independence between channel feature maps.
+
+    Refer to :func:`mindspore.ops.dropout2d` for more details.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> dropout = nn.Dropout2d(p=0.5)
+        >>> x = Tensor(np.ones([2, 1, 2, 3]), mindspore.float32)
+        >>> output = dropout(x)
+        >>> print(output.shape)
+        (2, 1, 2, 3)
+    """
+
+    def __init__(self, p=0.5):
+        """Initialize Dropout2d."""
+        super(Dropout2d, self).__init__()
+        Validator.check_value_type('p', p, [float], self.cls_name)
+        if p < 0 or p > 1:
+            raise ValueError(f"For '{self.cls_name}', the 'p' must be a number in range [0, 1], "
+                             f"but got {p}.")
+        self.keep_prob = 1.0 - p
+        self.dropout2d = P.Dropout2D(self.keep_prob)
+
+    def construct(self, x):
+        if not self.training or self.keep_prob == 1:
+            return x
+
+        out, _ = self.dropout2d(x)
+        return out
+
+    def extend_repr(self):
+        # Report the dropout rate p, not the internally stored keep_prob.
+        return f"p={1.0 - self.keep_prob}"
+
+
+class Dropout3d(Cell):
+    r"""
+    During training, randomly zeroes some channels of the input tensor
+    with probability `p` from a Bernoulli distribution (For a 5-dimensional tensor with
+    a shape of :math:`NCDHW`, the channel feature map refers to a 3-dimensional feature
+    map with a shape of :math:`DHW`).
+
+    For example, the :math:`j\_th` channel of the :math:`i\_th` sample in the batched input is a to-be-processed
+    `3D` tensor input[i,j].
+    Each channel will be zeroed out independently on every forward call, based on a Bernoulli distribution
+    with probability `p`.
+
+    `Dropout3d` can improve the independence between channel feature maps.
+
+    Refer to :func:`mindspore.ops.dropout3d` for more details.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> dropout = nn.Dropout3d(p=0.5)
+        >>> x = Tensor(np.ones([2, 1, 2, 1, 2]), mindspore.float32)
+        >>> output = dropout(x)
+        >>> print(output.shape)
+        (2, 1, 2, 1, 2)
+    """
+
+    def __init__(self, p=0.5):
+        """Initialize Dropout3d."""
+        super(Dropout3d, self).__init__()
+        Validator.check_value_type('p', p, [float], self.cls_name)
+        if p < 0 or p > 1:
+            raise ValueError(f"For '{self.cls_name}', the 'p' must be a number in range [0, 1], "
+                             f"but got {p}.")
+        self.keep_prob = 1.0 - p
+        self.dropout3d = P.Dropout3D(self.keep_prob)
+
+    def construct(self, x):
+        if not self.training or self.keep_prob == 1:
+            return x
+
+        out, _ = self.dropout3d(x)
+        return out
+
+    def extend_repr(self):
+        # Report the dropout rate p, not the internally stored keep_prob.
+        return f'p={1.0 - self.keep_prob}'
+
+
+class Upsample(Cell):
+    r"""
+    For details, please refer to :func:`mindspore.ops.interpolate`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> x = ms.Tensor([[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]])
+        >>> upsample = ms.nn.Upsample(size=(5, 5))
+        >>> out = upsample(x)
+        >>> print(x.asnumpy())
+        [[[[1. 2. 3. 4.]
+           [5. 6. 7. 8.]]]]
+        >>> print(out.asnumpy())
+        [[[[1. 1. 2. 3. 4.]
+           [1. 1. 2. 3. 4.]
+           [1. 1. 2. 3. 4.]
+           [5. 5. 6. 7. 8.]
+           [5. 5. 6. 7. 8.]]]]
+        >>> print(out.shape)
+        (1, 1, 5, 5)
+    """
+
+    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None):
+        """Initialize Upsample."""
+        super(Upsample, self).__init__()
+        self.size = size
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+        self.recompute_scale_factor = recompute_scale_factor
+
+    def construct(self, x):
+        out = F.interpolate(x, self.size, self.scale_factor, self.mode,
+                            self.align_corners, self.recompute_scale_factor)
+        return out
+
+
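The docstring output above is what plain nearest-neighbor interpolation produces: each destination index maps back to source index floor(dst * in / out). A NumPy sketch of that rule for the 4-D ``mode="nearest"`` case (illustrative only):

    import numpy as np

    def upsample_nearest_2d(x, size):
        # x: (N, C, H, W); nearest source index is floor(dst * in / out).
        h_in, w_in = x.shape[-2:]
        rows = np.arange(size[0]) * h_in // size[0]
        cols = np.arange(size[1]) * w_in // size[1]
        return x[..., rows[:, None], cols[None, :]]

    x = np.array([[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]])
    print(upsample_nearest_2d(x, (5, 5))[0, 0])
    # [1. 1. 2. 3. 4.] three times, then [5. 5. 6. 7. 8.] twice,
    # matching the docstring output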
+class UpsampleExt(Cell):
+    r"""
+    For details, please refer to :func:`mindspore.mint.nn.functional.interpolate`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import nn
+        >>> x = ms.Tensor([[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]])
+        >>> upsample = nn.UpsampleExt(size=(5, 5))
+        >>> out = upsample(x)
+        >>> print(x.asnumpy())
+        [[[[1. 2. 3. 4.]
+           [5. 6. 7. 8.]]]]
+        >>> print(out.asnumpy())
+        [[[[1. 1. 2. 3. 4.]
+           [1. 1. 2. 3. 4.]
+           [1. 1. 2. 3. 4.]
+           [5. 5. 6. 7. 8.]
+           [5. 5. 6. 7. 8.]]]]
+        >>> print(out.shape)
+        (1, 1, 5, 5)
+    """
+
+    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None):
+        """Initialize UpsampleExt."""
+        super(UpsampleExt, self).__init__()
+        self.size = size
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+        self.recompute_scale_factor = recompute_scale_factor
+
+    def construct(self, input):
+        out = interpolate_ext(input, self.size, self.scale_factor, self.mode,
+                              self.align_corners, self.recompute_scale_factor)
+        return out
+
+
+class Flatten(Cell):
+    r"""
+    Flatten the input Tensor along dimensions from `start_dim` to `end_dim`.
+
+    Args:
+        start_dim (int, optional): The first dimension to flatten. Default: ``1`` .
+        end_dim (int, optional): The last dimension to flatten. Default: ``-1`` .
+
+    Inputs:
+        - **x** (Tensor) - The input Tensor to be flattened.
+
+    Outputs:
+        Tensor. If no dimensions are flattened, returns the original `x`, otherwise return the flattened Tensor.
+        If `x` is a 0-dimensional Tensor, a 1-dimensional Tensor will be returned.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+        TypeError: If `start_dim` or `end_dim` is not int.
+        ValueError: If `start_dim` is greater than `end_dim` after canonicalization.
+        ValueError: If `start_dim` or `end_dim` is not in range of [-x.dim, x.dim-1]. For example, this happens
+            when the default values are used for the args and the input is a 0-dimensional or 1-dimensional Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
+        >>> net = nn.Flatten()
+        >>> output = net(x)
+        >>> print(output)
+        [[1.2 1.2 2.1 2.1]
+         [2.2 2.2 3.2 3.2]]
+        >>> print(f"before flatten the x shape is {x.shape}")
+        before flatten the x shape is (2, 2, 2)
+        >>> print(f"after flatten the output shape is {output.shape}")
+        after flatten the output shape is (2, 4)
+    """
+
+    def __init__(self, start_dim=1, end_dim=-1):
+        """Initialize Flatten."""
+        super(Flatten, self).__init__()
+        self.start_dim = start_dim
+        self.end_dim = end_dim
+
+    def check_axis_valid(self, axis, ndim):
+        if axis < -ndim or axis >= ndim:
+            raise ValueError("'start_dim' or 'end_dim' out of range.")
+
+    def construct(self, x):
+        x_rank = F.rank(x)
+        ndim = x_rank if x_rank != 0 else 1
+        self.check_axis_valid(self.start_dim, ndim)
+        self.check_axis_valid(self.end_dim, ndim)
+        return F.flatten(x, start_dim=self.start_dim, end_dim=self.end_dim)
+
+
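Flatten merges the dimensions from `start_dim` through `end_dim` into a single axis; everything else is a reshape. A NumPy equivalent, as an illustrative sketch:

    import math
    import numpy as np

    def flatten(x, start_dim=1, end_dim=-1):
        # Canonicalize negative indices, then merge dims [start, end] into one.
        start, end = start_dim % x.ndim, end_dim % x.ndim
        merged = math.prod(x.shape[start:end + 1])
        return x.reshape(x.shape[:start] + (merged,) + x.shape[end + 1:])

    x = np.zeros((2, 2, 2))
    print(flatten(x).shape)        # (2, 4), as in the docstring example
    print(flatten(x, 0, 1).shape)  # (4, 2)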
+class Identity(Cell):
+    r"""
+    A placeholder identity operator that returns its input unchanged.
+
+    Inputs:
+        - **x** (Any) - The input of Identity.
+
+    Outputs:
+        The same as `x`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
+        >>> net = nn.Identity()
+        >>> output = net(x)
+        >>> print(output)
+        [1 2 3 4]
+    """
+
+    def __init__(self):
+        """Initialize Identity."""
+        super(Identity, self).__init__()
+
+    def construct(self, x):
+        return x
+
+
+class Dense(Cell):
+    r"""
+    The dense connected layer.
+
+    Applies a dense connected layer to the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),
+
+    where :math:`X` is the input tensor, :math:`\text{activation}` is the activation function passed as the activation
+    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
+    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+    with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            weight will be initialized using HeUniform.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            bias will be initialized using Uniform.
+        has_bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
+        activation (Union[str, Cell, Primitive, None]): activate function applied to the output of the fully connected
+            layer. Both activation name, e.g. 'relu', and mindspore activation function, e.g. mindspore.ops.ReLU(),
+            are supported. Default: ``None`` .
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
+          to :math:`in\_channels` in `Inputs`.
+
+    Outputs:
+        Tensor of shape :math:`(*, out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` or `out_channels` is not an int.
+        TypeError: If `has_bias` is not a bool.
+        TypeError: If `activation` is not one of str, Cell, Primitive, None.
+        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+            is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
+        ValueError: If length of shape of `bias_init` is not equal to 1
+            or shape[0] of `bias_init` is not equal to `out_channels`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = nn.Dense(3, 4)
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 4)
+    """
+
+    @cell_attr_register(attrs=['has_bias', 'activation'])
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init=None,
+                 bias_init=None,
+                 has_bias=True,
+                 activation=None,
+                 dtype=mstype.float32):
+        """Initialize Dense."""
+        super(Dense, self).__init__()
+        self.in_channels = Validator.check_positive_int(
+            in_channels, "in_channels", self.cls_name)
+        self.out_channels = Validator.check_positive_int(
+            out_channels, "out_channels", self.cls_name)
+        self.has_bias = Validator.check_bool(
+            has_bias, "has_bias", self.cls_name)
+        self.reshape = P.Reshape()
+        self.shape_op = P.Shape()
+
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
+                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
+                                 f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
+                                 f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
+                                 f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
+        if weight_init is None:
+            weight_init = HeUniform(math.sqrt(5))
+        self.weight = Parameter(initializer(
+            weight_init, [out_channels, in_channels], dtype=dtype), name="weight")
+
+        self.bias = None
+        if self.has_bias:
+            if isinstance(bias_init, Tensor):
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
+                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
+                                     f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
+                                     f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
+            if bias_init is None:
+                bound = 1 / math.sqrt(in_channels)
+                bias_init = Uniform(scale=bound)
+            self.bias = Parameter(initializer(
+                bias_init, [out_channels], dtype=dtype), name="bias")
+            self.bias_add = P.BiasAdd()
+
+        self.matmul = P.MatMul(transpose_b=True)
+        self.activation = get_activation(activation) if isinstance(
+            activation, str) else activation
+        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
+            raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
+                            f"{type(activation).__name__}.")
+        self.activation_flag = self.activation is not None
+
+    def construct(self, x):
+        x_shape = self.shape_op(x)
+        if len(x_shape) != 2:
+            x = self.reshape(x, (-1, x_shape[-1]))
+        x = self.matmul(x, self.weight)
+        if self.has_bias:
+            x = self.bias_add(x, self.bias)
+        if self.activation_flag:
+            x = self.activation(x)
+        if len(x_shape) != 2:
+            out_shape = x_shape[:-1] + (F.shape(x)[-1],)
+            x = self.reshape(x, out_shape)
+        return x
+
+    def extend_repr(self):
+        s = f'input_channels={self.in_channels}, output_channels={self.out_channels}'
+        if self.has_bias:
+            s += f', has_bias={self.has_bias}'
+        if self.activation_flag:
+            s += f', activation={self.activation}'
+        return s
+
+
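For inputs with more than two dimensions, `construct` collapses all leading axes, applies the 2-D matmul against the transposed weight, and restores the leading axes afterwards. A NumPy sketch of that path (illustrative only):

    import numpy as np

    def dense(x, weight, bias=None):
        # weight: (out_channels, in_channels), as with P.MatMul(transpose_b=True)
        shape = x.shape
        if x.ndim != 2:
            x = x.reshape(-1, shape[-1])        # collapse leading dims
        y = x @ weight.T
        if bias is not None:
            y = y + bias
        if len(shape) != 2:
            y = y.reshape(shape[:-1] + (y.shape[-1],))  # restore leading dims
        return y

    x = np.ones((5, 7, 3), dtype=np.float32)
    w = np.ones((4, 3), dtype=np.float32)
    print(dense(x, w).shape)  # (5, 7, 4)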
+class Linear(Cell):
+    r"""
+    The linear connected layer.
+
+    Applies a linear connected layer to the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = X * \text{kernel} + \text{bias}
+
+    where :math:`X` is the input tensor, :math:`\text{kernel}` is a weight matrix with the same
+    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+    with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+    Args:
+        in_features (int): The number of features in the input space.
+        out_features (int): The number of features in the output space.
+        bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            weight will be initialized using HeUniform.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            bias will be initialized using Uniform.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``None`` .
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(*, in\_features)`. The `in_features` in `Args` should be equal
+          to :math:`in\_features` in `Inputs`.
+
+    Outputs:
+        Tensor of shape :math:`(*, out\_features)`.
+
+    Raises:
+        TypeError: If `in_features` or `out_features` is not an int.
+        TypeError: If `bias` is not a bool.
+        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+            is not equal to `out_features` or shape[1] of `weight_init` is not equal to `in_features`.
+        ValueError: If length of shape of `bias_init` is not equal to 1
+            or shape[0] of `bias_init` is not equal to `out_features`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor
+        >>> from mindspore import nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = nn.Linear(3, 4)
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 4)
+    """
+
+    @cell_attr_register(attrs=['has_bias'])
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 weight_init=None,
+                 bias_init=None,
+                 dtype=None):
+        """Initialize Linear."""
+        super(Linear, self).__init__()
+        self.in_features = Validator.check_positive_int(
+            in_features, "in_features", self.cls_name)
+        self.out_features = Validator.check_positive_int(
+            out_features, "out_features", self.cls_name)
+        self.has_bias = Validator.check_bool(
+            bias, "has_bias", self.cls_name)
+        self.dense = P.Dense()
+        if dtype is None:
+            dtype = mstype.float32
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_features or \
+                    weight_init.shape[1] != in_features:
+                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
+                                 f"be equal to 2, and the first dim must be equal to 'out_features', and the "
+                                 f"second dim must be equal to 'in_features'. But got 'weight_init': {weight_init}, "
+                                 f"'out_features': {out_features}, 'in_features': {in_features}.")
+        if weight_init is None:
+            weight_init = HeUniform(math.sqrt(5))
+        self.weight = Parameter(initializer(
+            weight_init, [out_features, in_features], dtype=dtype), name="weight")
+
+        self.bias = None
+        if self.has_bias:
+            if isinstance(bias_init, Tensor):
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_features:
+                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
+                                     f"be equal to 1, and the first dim must be equal to 'out_features'. But got "
+                                     f"'bias_init': {bias_init}, 'out_features': {out_features}.")
+            if bias_init is None:
+                bound = 1 / math.sqrt(in_features)
+                bias_init = Uniform(scale=bound)
+            self.bias = Parameter(initializer(
+                bias_init, [out_features], dtype=dtype), name="bias")
+
+    def construct(self, x):
+        x = self.dense(x, self.weight, self.bias)
+        return x
+
+    def extend_repr(self):
+        s = f'input_features={self.in_features}, output_features={self.out_features}'
+        if self.has_bias:
+            s += f', has_bias={self.has_bias}'
+        return s
+
+
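Linear differs from Dense mainly in that it delegates the whole affine map, including the bias add and any reshaping, to the fused P.Dense primitive, and it accepts no activation. Assuming the primitive contracts over the trailing axis the way Dense does, the computation is equivalent to this NumPy einsum (illustrative sketch):

    import numpy as np

    x = np.ones((2, 5, 3), dtype=np.float32)
    w = np.ones((4, 3), dtype=np.float32)   # (out_features, in_features)
    b = np.zeros(4, dtype=np.float32)
    # y[..., j] = sum_k x[..., k] * w[j, k] + b[j]
    y = np.einsum('...k,jk->...j', x, w) + b
    print(y.shape)  # (2, 5, 4)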
+@constexpr
+def _is_equal_one(x):
+    if x is None:
+        return False
+    return F.equal(F.reduce_mean(x), 1.0)
+
+
+@constexpr
+def _dtype_check(x_dtype, prim_name=None):
+    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
+    if x_dtype not in [mstype.float32, mstype.float16]:
+        raise TypeError(
+            f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.")
+
+
+@constexpr
+def _is_float_dtype(dtype):
+    if dtype in [mstype.float32, mstype.float16]:
+        return True
+    return False
+
+
+@constexpr
+def _need_reduce_all(axis):
+    if axis == ():
+        return True
+    return False
+
+
class ClipByNorm(Cell):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    The output of this layer remains the same if the :math:`L_2`-norm of the input tensor
    is not greater than the argument `clip_norm`. Otherwise the tensor will be normalized as:

    .. math::
        \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},

    where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specified dimension(s).
            Default: ``None`` , which means all dimensions are used in the calculation.

    Inputs:
        - **x** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
          or a tensor whose shape can be broadcast to the shape of `x`.

    Outputs:
        Tensor, clipped tensor with the same shape as `x`, whose type is float32.

    Raises:
        TypeError: If `axis` is not one of None, int, tuple.
        TypeError: If dtype of `x` is neither float32 nor float16.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> import numpy as np
        >>> net = nn.ClipByNorm()
        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
        >>> output = net(x, clip_norm)
        >>> print(output.shape)
        (4, 16)

    """

    def __init__(self, axis=None):
        """Initialize ClipByNorm."""
        super(ClipByNorm, self).__init__()
        self.clip_by_norm = inner.ClipByNorm(axis)

    def construct(self, x, clip_norm):
        values_clip = self.clip_by_norm(x, clip_norm)
        return values_clip

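# A quick numeric check of the formula above (a sketch, not part of the
# original module; the helper name is illustrative only): when the L2-norm
# of `x` exceeds `clip_norm`, the output equals clip_norm * x / L2(x).
def _example_clip_by_norm():
    import numpy as np
    import mindspore
    from mindspore import Tensor, nn
    x = Tensor(np.array([[3.0, 4.0]]), mindspore.float32)    # L2-norm is 5
    clip_norm = Tensor(np.array([1.0]), mindspore.float32)
    # Expected output: [[0.6, 0.8]], i.e. 1 * x / 5.
    return nn.ClipByNorm()(x, clip_norm)
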
class Norm(Cell):
    r"""
    The Norm class will be deprecated in the future;
    it can be replaced by :func:`ops.norm`.
    """

    @deprecated("2.0", "ops.norm", False)
    def __init__(self, axis=(), keep_dims=False):
        """Initialize Norm."""
        super(Norm, self).__init__()
        Validator.check_value_type(
            "keep_dims", keep_dims, [bool], self.cls_name)
        self.axis = axis
        self.keep_dims = keep_dims
        self.reduce_sum = P.ReduceSum(True)
        self.sqrt = P.Sqrt()
        self.squeeze = P.Squeeze(self.axis)

    def construct(self, x):
        # Square root of the sum of squares along `axis`, i.e. the L2-norm.
        x = self.sqrt(self.reduce_sum(F.square(x), self.axis))

        if not self.keep_dims:
            x = self.squeeze(x)
        return x

    def extend_repr(self):
        return f'axis={self.axis}, keep_dims={self.keep_dims}'

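# Since Norm is deprecated, a sketch of the suggested replacement,
# ops.norm (helper name illustrative only):
def _example_norm_replacement():
    import mindspore
    from mindspore import Tensor, ops
    x = Tensor([[3.0, 4.0], [6.0, 8.0]], mindspore.float32)
    # 2-norm of each row, matching Norm(axis=1)(x): [5., 10.]
    return ops.norm(x, dim=1)
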
class OneHot(Cell):
    """
    The OneHot class will be deprecated in the future;
    it can be replaced by :func:`ops.one_hot`.
    """

    @deprecated("2.0", "ops.one_hot", False)
    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32):
        """Initialize OneHot."""
        super(OneHot, self).__init__()
        self.onehot = P.OneHot(axis)
        self.depth = depth
        self.dtype = dtype
        self.on_value = on_value
        self.off_value = off_value

    def construct(self, indices):
        return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))

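# A sketch of the suggested replacement, ops.one_hot (helper name
# illustrative only); on/off values are passed as scalar Tensors:
def _example_one_hot_replacement():
    import mindspore
    from mindspore import Tensor, ops
    indices = Tensor([0, 1, 2], mindspore.int32)
    on_value = Tensor(1.0, mindspore.float32)
    off_value = Tensor(0.0, mindspore.float32)
    # Matches OneHot(depth=3)(indices): a 3x3 identity matrix.
    return ops.one_hot(indices, 3, on_value, off_value)
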
class Pad(Cell):
    r"""
    Pads the input tensor according to the paddings and mode.

    Args:
        paddings (tuple): The shape of parameter `paddings` is :math:`(N, 2)` . N is the rank of input data. All
            elements of paddings are int type. For the `D` th dimension of `x`, paddings[D, 0] indicates how many
            values to pad before the contents of `x` in the `D` th dimension, and paddings[D, 1] indicates how many
            values to pad after the contents of `x` in the `D` th dimension. The padded size of each
            dimension D of the output is: :math:`paddings[D, 0] + input\_x.dim\_size(D) + paddings[D, 1]`,
            e.g.:

            .. code-block::

                mode = "CONSTANT".
                paddings = [[1,1], [2,2]].
                x = [[1,2,3], [4,5,6], [7,8,9]].
                # The above can be seen: 1st dimension of `x` is 3, 2nd dimension of `x` is 3.
                # Substitute into the formula to get:
                # 1st dimension of output is paddings[0][0] + 3 + paddings[0][1] = 1 + 3 + 1 = 5.
                # 2nd dimension of output is paddings[1][0] + 3 + paddings[1][1] = 2 + 3 + 2 = 7.
                # So the shape of output is (5, 7).

        mode (str): Specifies padding mode. The optional values are ``"CONSTANT"`` , ``"REFLECT"`` , ``"SYMMETRIC"`` .
            Default: ``"CONSTANT"`` .

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        Tensor, the tensor after padding.

        - If `mode` is "CONSTANT", it fills the edge with 0, regardless of the values of the `x`.
          If the `x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
          output is [[0,0,0,0,0,0,0], [0,0,1,2,3,0,0], [0,0,4,5,6,0,0], [0,0,7,8,9,0,0], [0,0,0,0,0,0,0]].
        - If `mode` is "REFLECT", it uses a way of symmetrical copying through the axis of symmetry to fill in.
          If the `x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
          output is [[6,5,4,5,6,5,4], [3,2,1,2,3,2,1], [6,5,4,5,6,5,4], [9,8,7,8,9,8,7], [6,5,4,5,6,5,4]].
        - If `mode` is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied
          according to the symmetry axis, except that it includes the symmetry axis. If the `x`
          is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the output is
          [[2,1,1,2,3,3,2], [2,1,1,2,3,3,2], [5,4,4,5,6,6,5], [8,7,7,8,9,9,8], [8,7,7,8,9,9,8]].

    Raises:
        TypeError: If `paddings` is not a tuple.
        ValueError: If length of `paddings` is more than 4 or its shape is not :math:`(N, 2)` .
        ValueError: If `mode` is not one of ``"CONSTANT"``, ``"REFLECT"``, ``"SYMMETRIC"``.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn, ops
        >>> import numpy as np
        >>> # If `mode` is "CONSTANT"
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="CONSTANT")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.float32)
        >>> pad = Net()
        >>> output = pad(x)
        >>> print(output)
        [[0. 0. 0. 0. 0. 0. 0.]
         [0. 0. 1. 2. 3. 0. 0.]
         [0. 0. 4. 5. 6. 0. 0.]
         [0. 0. 0. 0. 0. 0. 0.]]
        >>> # Another way to call
        >>> pad = ops.Pad(paddings=((1, 1), (2, 2)))
        >>> # From the above code, we can see the following:
        >>> # "paddings=((1, 1), (2, 2))",
        >>> # paddings[0][0] = 1 indicates a row of values is filled on top of the input data in the 1st dimension.
        >>> # Shown as follows:
        >>> # [[0. 0. 0.]
        >>> #  [1. 2. 3.]
        >>> #  [4. 5. 6.]]
        >>> # paddings[0][1] = 1 indicates a row of values is filled below the input data in the 1st dimension.
        >>> # Shown as follows:
        >>> # [[0. 0. 0.]
        >>> #  [1. 2. 3.]
        >>> #  [4. 5. 6.]
        >>> #  [0. 0. 0.]]
        >>> # paddings[1][0] = 2 indicates 2 columns of values are filled in front of the input data in the 2nd dimension.
        >>> # Shown as follows:
        >>> # [[0. 0. 0. 0. 0.]
        >>> #  [0. 0. 1. 2. 3.]
        >>> #  [0. 0. 4. 5. 6.]
        >>> #  [0. 0. 0. 0. 0.]]
        >>> # paddings[1][1] = 2 indicates 2 columns of values are filled behind the input data in the 2nd dimension.
        >>> # Shown as follows:
        >>> # [[0. 0. 0. 0. 0. 0. 0.]
        >>> #  [0. 0. 1. 2. 3. 0. 0.]
        >>> #  [0. 0. 4. 5. 6. 0. 0.]
        >>> #  [0. 0. 0. 0. 0. 0. 0.]]
        >>> output = pad(x)
        >>> print(output)
        [[0. 0. 0. 0. 0. 0. 0.]
         [0. 0. 1. 2. 3. 0. 0.]
         [0. 0. 4. 5. 6. 0. 0.]
         [0. 0. 0. 0. 0. 0. 0.]]
        >>> # if mode is "REFLECT"
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="REFLECT")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.float32)
        >>> pad = Net()
        >>> output = pad(x)
        >>> print(output)
        [[6. 5. 4. 5. 6. 5. 4.]
         [3. 2. 1. 2. 3. 2. 1.]
         [6. 5. 4. 5. 6. 5. 4.]
         [3. 2. 1. 2. 3. 2. 1.]]
        >>> # if mode is "SYMMETRIC"
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="SYMMETRIC")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.float32)
        >>> pad = Net()
        >>> output = pad(x)
        >>> print(output)
        [[2. 1. 1. 2. 3. 3. 2.]
         [2. 1. 1. 2. 3. 3. 2.]
         [5. 4. 4. 5. 6. 6. 5.]
         [5. 4. 4. 5. 6. 6. 5.]]
    """

    def __init__(self, paddings, mode="CONSTANT"):
        """Initialize Pad."""
        super(Pad, self).__init__()
        self.mode = mode
        self.paddings = paddings
        Validator.check_string(
            self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
        if not isinstance(paddings, tuple):
            raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, "
                            f"but got {type(paddings).__name__}.")
        for item in paddings:
            if len(item) != 2:
                raise ValueError(f"For '{self.cls_name}', the shape of 'paddings' must be (n, 2), "
                                 f"but got {paddings}.")
        if len(paddings) > 4:
            raise ValueError(f"For '{self.cls_name}', only 'paddings' up to 4 dims is supported, but got "
                             f"{len(paddings)}.")
        if mode == "CONSTANT":
            self.pad = P.Pad(self.paddings)
        else:
            # MirrorPad takes the paddings as a tensor input rather than an attribute.
            self.paddings = Tensor(np.array(self.paddings), dtype=mstype.int64)
            self.pad = P.MirrorPad(mode=mode)

    def construct(self, x):
        if self.mode == "CONSTANT":
            x = self.pad(x)
        else:
            x = self.pad(x, self.paddings)
        return x

class Unfold(Cell):
    r"""
    Extracts patches from images.
    The input tensor must be a 4-D tensor and the data format is NCHW.

    Args:
        ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
            and the format is [1, ksize_row, ksize_col, 1].
        strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches,
            must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between corresponding pixel
            positions along each dimension, must be a tuple or a list of integers, and the format is
            [1, rate_row, rate_col, 1].
        padding (str): The type of padding algorithm, is a string whose value is ``"same"`` or ``"valid"`` , not case
            sensitive. Default: ``"valid"`` .

            - ``"same"``: Means that the patch can take the part beyond the original image, and this part is filled
              with 0.

            - ``"valid"``: Means that the taken patch area must be completely covered in the original image.

    Inputs:
        - **x** (Tensor) - A 4-D tensor whose shape is :math:`[in\_batch, in\_depth, in\_row, in\_col]`
          and data type is number.

    Outputs:
        Tensor, a 4-D tensor whose data type is same as `x`,
        and the shape is :math:`(out\_batch, out\_depth, out\_row, out\_col)`
        where `out_batch` is the same as the `in_batch`.

        - :math:`out\_depth = ksize\_row * ksize\_col * in\_depth`
        - :math:`out\_row = (in\_row - (ksize\_row + (ksize\_row - 1) * (rate\_row - 1))) // stride\_row + 1`
        - :math:`out\_col = (in\_col - (ksize\_col + (ksize\_col - 1) * (rate\_col - 1))) // stride\_col + 1`

    Raises:
        TypeError: If `ksizes`, `strides` or `rates` is neither a tuple nor a list.
        ValueError: If the shape of `ksizes`, `strides` or `rates` is not :math:`(1, x\_row, x\_col, 1)`.
        ValueError: If the second or third element of `ksizes`, `strides` or `rates` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> import numpy as np
        >>> net = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1], rates=[1, 2, 2, 1])
        >>> # As stated in the above code:
        >>> # ksize_row = 2, ksize_col = 2, rate_row = 2, rate_col = 2, stride_row = 2, stride_col = 2.
        >>> image = Tensor(np.ones([2, 3, 6, 6]), dtype=mindspore.float16)
        >>> # in_batch = 2, in_depth = 3, in_row = 6, in_col = 6.
        >>> # Substituting the formula to get:
        >>> # out_batch = in_batch = 2
        >>> # out_depth = 2 * 2 * 3 = 12
        >>> # out_row = (6 - (2 + (2 - 1) * (2 - 1))) // 2 + 1 = 2
        >>> # out_col = (6 - (2 + (2 - 1) * (2 - 1))) // 2 + 1 = 2
        >>> output = net(image)
        >>> print(output.shape)
        (2, 12, 2, 2)
    """

    def __init__(self, ksizes, strides, rates, padding="valid"):
        """Initialize Unfold."""
        super(Unfold, self).__init__()

        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            # Validate the argument being checked, not `ksizes` (the original
            # code checked `ksizes` for every argument, which was a bug).
            Validator.check_value_type(f"{arg_name}s", arg_val, [
                tuple, list], prim_name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
                raise ValueError(f"For '{prim_name}', the format of '{arg_name}s' must be [1, {arg_name}_row, "
                                 f"{arg_name}_col, 1], but got {arg_val}.")
            is_int = isinstance(arg_val[1], int) and isinstance(arg_val[2], int)
            if not is_int or arg_val[1] < 1 or arg_val[2] < 1:
                raise ValueError(f"For '{prim_name}', the {arg_name}_row and {arg_name}_col in '{arg_name}s' must be "
                                 f"a positive integer, but got {arg_name}_row is {arg_val[1]}, "
                                 f"{arg_name}_col is {arg_val[2]}")

        _check_tuple_or_list("ksize", ksizes, self.cls_name)
        _check_tuple_or_list("stride", strides, self.cls_name)
        _check_tuple_or_list("rate", rates, self.cls_name)
        # Reorder from [1, row, col, 1] to the [1, 1, row, col] layout
        # expected by the ExtractImagePatches primitive.
        ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
        strides = strides[0], strides[3], strides[1], strides[2]
        rates = rates[0], rates[3], rates[1], rates[2]
        self.extract_image_patches = inner.ExtractImagePatches(
            ksizes, strides, rates, padding)

    def construct(self, input_x):
        result = self.extract_image_patches(input_x)
        return result

class UnfoldExt(Cell):
    r"""
    Extracts sliding local blocks from a batched input tensor.

    For details, please refer to :func:`mindspore.mint.nn.functional.unfold`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, nn
        >>> input = Tensor(np.random.rand(4, 4, 32, 32), mindspore.float64)
        >>> unfold = nn.UnfoldExt(kernel_size=3, dilation=1, stride=1)
        >>> output = unfold(input)
        >>> print(output.shape)
        (4, 36, 900)
    """
    def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
        super(UnfoldExt, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def construct(self, input):
        return unfold_ext(input, self.kernel_size, self.dilation, self.padding, self.stride)

class Fold(Cell):
    r"""
    Combines an array of sliding local blocks into a large containing tensor.

    For details, please refer to :func:`mindspore.mint.nn.functional.fold`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, nn
        >>> from mindspore import dtype as mstype
        >>> fold = nn.Fold([8, 8], [2, 2], [2, 2], [2, 2], [2, 2])
        >>> input = Tensor(input_data=np.random.rand(16, 64, 25), dtype=mstype.float32)
        >>> output = fold(input)
        >>> print(output.shape)
        (16, 16, 8, 8)
    """
    def __init__(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
        super(Fold, self).__init__()
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def construct(self, input):
        return fold_ext(input, self.output_size, self.kernel_size,
                        self.dilation, self.padding, self.stride)

@_primexpr
def tril(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Validator.GE, "x rank", "tril")
    Validator.check_is_int(k, "k value", "tril")
    value = F.cast(P.Tril(diagonal=k)(F.ones(x_shape, x_dtype)), x_dtype)
    return value


class Tril(Cell):
    """
    The Tril class will be deprecated in the future;
    it can be replaced by :func:`ops.tril`.
    """

    @deprecated("2.0", "ops.tril", False)
    def __init__(self):
        """Initialize Tril."""
        super(Tril, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()
        self.cast = P.Cast()

    def construct(self, x, k=0):
        # Build a lower-triangular mask of ones and multiply it in, computing
        # in float32 before casting back to the input dtype.
        assist = tril(x.shape, self.dtype(x), k)
        result = self.mul(self.cast(x, mstype.float32),
                          self.cast(assist, mstype.float32))
        return self.cast(result, self.dtype(x))

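# A sketch of the suggested replacement, ops.tril (helper name
# illustrative only):
def _example_tril_replacement():
    import mindspore
    from mindspore import Tensor, ops
    x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.float32)
    # Zeros above the main diagonal: [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
    return ops.tril(x)
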
@_primexpr
def triu(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Validator.GE, "x rank", "triu")
    Validator.check_is_int(k, "k value", "triu")
    value = F.cast(P.Triu(k)(F.ones(x_shape, x_dtype)), x_dtype)
    return value


class Triu(Cell):
    """
    The Triu class will be deprecated in the future;
    it can be replaced by :func:`ops.triu`.
    """

    @deprecated("2.0", "ops.triu", False)
    def __init__(self):
        """Initialize Triu."""
        super(Triu, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()
        self.cast = P.Cast()

    def construct(self, x, k=0):
        # Same masking trick as Tril, with an upper-triangular mask.
        assist = triu(x.shape, self.dtype(x), k)
        result = self.mul(self.cast(x, mstype.float32),
                          self.cast(assist, mstype.float32))
        return self.cast(result, self.dtype(x))

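# The mirror-image sketch for ops.triu (helper name illustrative only):
def _example_triu_replacement():
    import mindspore
    from mindspore import Tensor, ops
    x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.float32)
    # Zeros below the main diagonal: [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
    return ops.triu(x)
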
@_primexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
    """Get matrix diag assist"""
    Validator.check_int(len(x_shape), 1, Validator.GE, "x rank", "_get_matrix_diag_assist")
    base_eye = F.reshape(
        F.eye(x_shape[-1], x_shape[-1], x_dtype), (x_shape[-1] * x_shape[-1],))
    if len(x_shape) == 1:
        assist = F.reshape(base_eye, x_shape + (x_shape[-1],))
    else:
        assist = F.reshape(
            F.tile(base_eye, x_shape[:-1]), x_shape + (x_shape[-1],))
    value = F.cast(assist, x_dtype)
    return value


@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
    """Get matrix diag part assist"""
    Validator.check_int(len(x_shape), 2, Validator.GE, "x rank", "_get_matrix_diag_part_assist")
    base_eye = F.reshape(
        F.eye(x_shape[-2], x_shape[-1], x_dtype), (x_shape[-2] * x_shape[-1],))
    if len(x_shape) <= 2:
        assist = F.reshape(base_eye, x_shape)
    else:
        assist = F.reshape(F.tile(base_eye, x_shape[:-2]), x_shape)
    value = F.cast(assist, x_dtype)
    return value

class MatrixDiag(Cell):
    r"""
    The MatrixDiag class will be deprecated in the future;
    it can be replaced by :func:`ops.diag`.
    """

    @deprecated("2.0", "ops.diag", False)
    def __init__(self):
        """Initialize MatrixDiag."""
        super(MatrixDiag, self).__init__()
        self.matrix_diag = inner.MatrixDiag()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_assist(x_shape, x_dtype)
        out_matrix_diag = self.matrix_diag(input_x, assist)
        return out_matrix_diag

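# A sketch of the suggested replacement, ops.diag (helper name
# illustrative only):
def _example_diag_replacement():
    import mindspore
    from mindspore import Tensor, ops
    x = Tensor([1.0, 2.0, 3.0], mindspore.float32)
    # A 3x3 matrix with [1., 2., 3.] on the main diagonal, zeros elsewhere.
    return ops.diag(x)
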
class MatrixDiagPart(Cell):
    r"""
    The MatrixDiagPart class will be deprecated in the future;
    it can be replaced by :func:`ops.diagonal`.
    """

    @deprecated("2.0", "ops.diagonal", False)
    def __init__(self):
        """Initialize MatrixDiagPart."""
        super(MatrixDiagPart, self).__init__()
        self.matrix_diag_part = inner.MatrixDiagPart()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_diag_part = self.matrix_diag_part(input_x, assist)
        return out_matrix_diag_part

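# A sketch of the suggested replacement, ops.diagonal (helper name
# illustrative only; availability may depend on the backend):
def _example_diagonal_replacement():
    import mindspore
    from mindspore import Tensor, ops
    x = Tensor([[1.0, 2.0], [3.0, 4.0]], mindspore.float32)
    # Main diagonal of the last two dimensions: [1., 4.]
    return ops.diagonal(x)
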
class MatrixSetDiag(Cell):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Assume `x` has :math:`k+1` dimensions :math:`[I, J, K, ..., M, N]` and `diagonal` has :math:`k`
    dimensions :math:`[I, J, K, ..., min(M, N)]`, the output is a tensor of rank :math:`k+1` with dimensions
    :math:`[I, J, K, ..., M, N]`, where:

    .. math::
        output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]\ for\ m == n

    .. math::
        output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n]\ for\ m != n

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.

    Outputs:
        Tensor, has the same type and shape as input `x`.

    Raises:
        TypeError: If dtype of `x` or `diagonal` is not one of float32, float16, int32, int8 or uint8.
        ValueError: If length of shape of `x` is less than 2.
        ValueError: If x_shape[-2] < x_shape[-1] and x_shape[:-1] != diagonal_shape.
        ValueError: If x_shape[-2] >= x_shape[-1] and x_shape[:-2] + x_shape[-1:] != diagonal_shape.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = nn.MatrixSetDiag()
        >>> output = matrix_set_diag(x, diagonal)
        >>> print(output)
        [[[-1.  0.]
          [ 0.  2.]]
         [[-1.  0.]
          [ 0.  1.]]
         [[-1.  0.]
          [ 0.  1.]]]
    """

    def __init__(self):
        """Initialize MatrixSetDiag."""
        super(MatrixSetDiag, self).__init__()
        self.matrix_set_diag = inner.MatrixSetDiag()
        self.dtype = P.DType()

    def construct(self, input_x, diagonal):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist)
        return out_matrix_set_diag

@constexpr
def _check_input_dim(axis, dim, cls_name):
    Validator.check_int_range(axis, -dim, dim, Validator.INC_LEFT, 'axis', cls_name)


class Roll(Cell):
    """
    The Roll class will be deprecated in the future;
    it can be replaced by :func:`ops.roll`.
    """

    @deprecated("2.0", "ops.roll", False)
    def __init__(self, shift, axis):
        """Initialize Roll"""
        super(Roll, self).__init__()
        Validator.check_value_type(
            "shift", shift, [int, tuple, list], self.cls_name)
        Validator.check_value_type(
            "axis", axis, [int, tuple, list], self.cls_name)
        self.shape_op = P.Shape()
        self.shift = shift
        self.axis = axis
        self.op_list = []
        self.gpu = False

        if not isinstance(self.axis, (list, tuple)):
            self.axis = [self.axis]
        if not isinstance(self.shift, (list, tuple)):
            self.shift = [self.shift]
        if context.get_context("device_target") == "GPU":
            # GPU has a fused Roll kernel that handles all axes at once.
            Validator.check_int(len(self.shift), 1, Validator.GE, "shift", "Roll")
            Validator.check_int(len(self.axis), 1, Validator.GE, "axis", "Roll")
            for s_axis in self.axis:
                Validator.check_is_int(s_axis, "axis", "Roll")
            for s_shift in self.shift:
                Validator.check_is_int(s_shift, "shift", "Roll")
            self.roll = P.Roll(self.shift, self.axis)
            self.gpu = True
            if len(self.shift) != len(self.axis):
                raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
                                 f"the same, but got the length of 'shift' {len(self.shift)} "
                                 f"and the length of 'axis' {len(self.axis)}.")
        else:
            # Other backends decompose the roll into one P.Roll per axis,
            # each applied over axis 0 after a transpose (see construct).
            if not isinstance(self.axis, (list, tuple)):
                self.op_list.append(
                    (P.Roll(shift=self.shift, axis=0), self.axis))
            else:
                if len(self.shift) != len(self.axis):
                    raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
                                     f"the same, but got the length of 'shift' {len(self.shift)} "
                                     f"and the length of 'axis' {len(self.axis)}.")
                for idx, _ in enumerate(self.axis):
                    self.op_list.append(
                        (P.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))

    def construct(self, input_x):
        dim = len(self.shape_op(input_x))
        if self.gpu:
            output = self.roll(input_x)
        else:
            for single_op_roll, single_axis in self.op_list:
                _check_input_dim(single_axis, dim, self.cls_name)
                if single_axis < 0:
                    single_axis += dim
                # Swap the roll axis to the front, roll along axis 0, then
                # swap it back; the permutation is its own inverse.
                transpose_perm = []
                for i in range(dim):
                    transpose_perm.append(i)
                transpose_perm[0], transpose_perm[single_axis] = single_axis, 0

                input_x = input_x.transpose(transpose_perm)
                input_x = single_op_roll(input_x)
                input_x = input_x.transpose(transpose_perm)
            output = input_x
        return output

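# A sketch of the suggested replacement, ops.roll (helper name illustrative
# only; platform support for roll may vary by backend):
def _example_roll_replacement():
    import mindspore
    from mindspore import Tensor, ops
    x = Tensor([0, 1, 2, 3, 4], mindspore.float32)
    # Shift every element two places to the right: [3., 4., 0., 1., 2.]
    return ops.roll(x, shifts=2, dims=0)
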
class Unflatten(Cell):
    r"""
    Unflattens a Tensor dim according to `axis` and `unflattened_size`.

    Args:
        axis (int): specifies the dimension of the input Tensor to be unflattened.
        unflattened_size (Union(tuple[int], list[int])): the new shape of the unflattened dimension of
            the Tensor; it can be a tuple of ints or a list of ints. The product of `unflattened_size`
            must equal input_shape[axis].

    Inputs:
        - **input** (Tensor) - The input Tensor to be unflattened.

    Outputs:
        Tensor that has been unflattened.

    Raises:
        TypeError: If `axis` is not int.
        TypeError: If `unflattened_size` is neither tuple of ints nor list of ints.
        TypeError: If the product of `unflattened_size` does not equal input_shape[axis].

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> import numpy as np
        >>> input = Tensor(np.arange(0, 100).reshape(2, 10, 5), mindspore.float32)
        >>> net = nn.Unflatten(1, (2, 5))
        >>> output = net(input)
        >>> print(f"before unflatten the input shape is {input.shape}")
        before unflatten the input shape is (2, 10, 5)
        >>> print(f"after unflatten the output shape is {output.shape}")
        after unflatten the output shape is (2, 2, 5, 5)
    """

    def __init__(self, axis, unflattened_size):
        """Initialize Unflatten."""
        super(Unflatten, self).__init__()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        Validator.check_is_int(axis, 'axis', 'Unflatten')
        Validator.check_value_type(
            'unflattened_size', unflattened_size, (list, tuple), 'Unflatten')
        self.axis = axis
        if isinstance(unflattened_size, list):
            unflattened_size = tuple(unflattened_size)
        self.unflattened_size = unflattened_size

    def construct(self, input_x):
        # Splice the unflattened sizes into the original shape in place of
        # the dimension at `axis`; axis == -1 has no trailing dims to keep.
        input_shape = self.shape(input_x)
        new_shape = tuple()
        new_shape += input_shape[: self.axis]
        new_shape += self.unflattened_size
        if self.axis != -1:
            new_shape += input_shape[self.axis + 1:]
        return self.reshape(input_x, new_shape)
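# A minimal sketch (not part of the original module) showing that Unflatten
# is a reshape around the chosen axis; the helper name is illustrative only.
def _example_unflatten():
    import numpy as np
    import mindspore
    from mindspore import Tensor, nn
    x = Tensor(np.zeros((4, 6)), mindspore.float32)
    # Split dim 1 (size 6) into (2, 3): output shape (4, 2, 3),
    # identical to x.reshape(4, 2, 3).
    return nn.Unflatten(1, (2, 3))(x)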