mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
mindspore/_checkparam.py
ADDED
|
@@ -0,0 +1,1419 @@
|
|
|
1
|
+
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Check parameters."""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
|
|
18
|
+
import re
|
|
19
|
+
import inspect
|
|
20
|
+
import math
|
|
21
|
+
from types import FunctionType, MethodType
|
|
22
|
+
from functools import reduce, wraps
|
|
23
|
+
from itertools import repeat
|
|
24
|
+
from collections.abc import Iterable
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
from mindspore import context
|
|
28
|
+
from mindspore import log as logger
|
|
29
|
+
from mindspore.common import dtype as mstype
|
|
30
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
31
|
+
|
|
32
|
+
# Binary relation codes: ==, !=, <, <=, >, >=
EQ, NE, LT, LE, GT, GE = 1, 2, 3, 4, 5, 6
# scalar range check codes; the brackets show which endpoints are included:
# INC_NEITHER (), INC_LEFT [), INC_RIGHT (], INC_BOTH []
INC_NEITHER, INC_LEFT, INC_RIGHT, INC_BOTH = 7, 8, 9, 10
# collection membership codes: `in`, `not in`
IN, NOT_IN = 11, 12
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _check_binary_rel(val1, val2, rel):
    """
    Evaluate the binary relation selected by `rel` between `val1` and `val2`.

    `rel` is one of EQ, NE, LT, LE, GT, GE, IN, NOT_IN. The result of the matching comparison is
    returned as-is; an unknown `rel` code yields False.
    """
    comparisons = (
        (EQ, lambda left, right: left == right),
        (NE, lambda left, right: left != right),
        (LT, lambda left, right: left < right),
        (LE, lambda left, right: left <= right),
        (GT, lambda left, right: left > right),
        (GE, lambda left, right: left >= right),
        (IN, lambda left, right: left in right),
        (NOT_IN, lambda left, right: left not in right),
    )
    for code, compare in comparisons:
        if rel == code:
            return compare(val1, val2)
    return False
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _check_inc_rel(val, lower, upper, rel):
    """
    Check whether `val` lies in the interval between `lower` and `upper`.

    `rel` selects which endpoints are inclusive:
        INC_NEITHER -> (lower, upper), INC_LEFT -> [lower, upper),
        INC_RIGHT -> (lower, upper], INC_BOTH -> [lower, upper].
    An unknown `rel` code yields False.

    The comparisons are written positively (``lower < val < upper``) rather than as the negation of the
    out-of-range test. A value that compares false against everything (NaN) is therefore rejected. With
    ``not (val <= lower or val >= upper)`` it would have been reported as in range.
    """
    if rel == INC_NEITHER:
        return bool(lower < val < upper)
    if rel == INC_LEFT:
        return bool(lower <= val < upper)
    if rel == INC_RIGHT:
        return bool(lower < val <= upper)
    if rel == INC_BOTH:
        return bool(lower <= val <= upper)

    return False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _format_str_one_value(value, rel):
    """
    Render relation `rel` applied to `value` for use in error messages, e.g. ``">= 0"``.

    An unknown `rel` code yields the empty string.
    """
    operators = (
        (EQ, "="),
        (NE, "!="),
        (LT, "<"),
        (LE, "<="),
        (GT, ">"),
        (GE, ">="),
        (IN, "in"),
        (NOT_IN, "not in"),
    )
    for code, symbol in operators:
        if rel == code:
            return f"{symbol} {value}"
    return ""
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _format_str_two_value(val1, val2, rel):
    """
    Render the interval between `val1` and `val2` in bracket notation for error messages.

    The bracket shape follows `rel` (INC_NEITHER "()", INC_LEFT "[)", INC_RIGHT "(]", INC_BOTH "[]").
    An unknown `rel` code yields the empty string.
    """
    brackets = (
        (INC_NEITHER, "(", ")"),
        (INC_LEFT, "[", ")"),
        (INC_RIGHT, "(", "]"),
        (INC_BOTH, "[", "]"),
    )
    for code, opening, closing in brackets:
        if rel == code:
            return f"{opening}{val1}, {val2}{closing}"
    return ""
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
                           greater_zero=True, third_one=False, three_input=False):
    """
    Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.

    An int `v` is expanded to ``(v, v, v)``. With `ret_five`, the result is returned in 5-element form
    ``(1, 1, a, b, c)``. Without `ret_five`, a 5-tuple is reduced to its last three elements.

    Args:
        greater_zero (bool): When True, elements must be > 0; otherwise >= 0. Bools are rejected.
        third_one (bool): When True, the third element from the end of the result must equal 1.
        three_input (bool): When True, tuple input must have exactly three elements.

    Returns:
        tuple: The normalized value.

    Raises:
        ValueError: On any violation of the rules above.
    """
    def _default_error():
        return ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer or " \
                          f"a tuple of three {'or five ' if allow_five else ''}positive integer, but got {arg_value}")

    check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    if three_input and isinstance(arg_value, tuple) and len(arg_value) != 3:
        raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer " \
                         f"or a tuple of three positive integer, but got {arg_value}.")

    if isinstance(arg_value, int):
        normalized = (1, 1) + (arg_value,) * 3 if ret_five else (arg_value,) * 3
    else:
        size = len(arg_value)
        if size == 5:
            if not allow_five:
                raise _default_error()
            normalized = arg_value if ret_five else arg_value[2:]
        elif size == 3:
            normalized = (1, 1) + tuple(arg_value) if ret_five else arg_value
        else:
            raise _default_error()

    # For ints, "> 0" is the same as ">= 1", so a single lower bound covers both modes.
    lowest = 1 if greater_zero else 0
    for item in normalized:
        if not isinstance(item, int) or isinstance(item, bool) or item < lowest:
            raise _default_error()

    if third_one and normalized[-3] != 1:
        raise ValueError(f"For '{prim_name}', the depth of parameter '{arg_name}' must be 1, " \
                         f"but got {normalized[-3]}.")

    return tuple(normalized)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _check_dup(axes):
|
|
181
|
+
for item in axes:
|
|
182
|
+
count = 0
|
|
183
|
+
for item2 in axes:
|
|
184
|
+
if item == item2:
|
|
185
|
+
count += 1
|
|
186
|
+
|
|
187
|
+
if count > 1:
|
|
188
|
+
raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a legal `arg_type` number satisfying relation `rel` against `value`.

    Usage:
        - arg_value = _check_number(arg_value, 2, GT, int, "value", None)

    Returns:
        `arg_value`, unchanged, when every check passes.

    Raises:
        TypeError: If `arg_value` is not an instance of `arg_type`, or is a bool.
        ValueError: If `arg_value` is a non-integral inf/NaN, or violates the relation.
    """
    prim_name = f"For \'{prim_name}\', the " if prim_name else 'The '
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
    prim_info = f'{prim_name}' + f'{arg_name}'

    if not isinstance(arg_value, arg_type):
        raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
    # Python ints can never be inf/NaN. math.isinf() also raises OverflowError for ints too large to
    # convert to float, so only non-integral values go through the inf/NaN screening.
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)
                                           or np.isinf(arg_value) or np.isnan(arg_value)):
        raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")

    # Once the isinstance check has passed, the only type problem left is bool (a subclass of int).
    type_mismatch = isinstance(arg_value, bool)
    if type_mismatch or not _check_binary_rel(arg_value, value, rel):
        rel_str = _format_str_one_value(value, rel)
        msg = f"{prim_info} must be {arg_type.__name__} and must {rel_str}, " \
              f"but got '{arg_value}' with type '{type(arg_value).__name__}'."
        if type_mismatch:
            raise TypeError(msg)
        raise ValueError(msg)

    return arg_value
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
    """
    Checks that the input value is an instance of `arg_type` (bools rejected) and is not inf/NaN.

    Usage:
        - number = check_is_number(number, int)
        - number = check_is_number(number, int, "bias")
        - number = check_is_number(number, int, "bias", "bias_class")

    Returns:
        `arg_value`, unchanged.

    Raises:
        TypeError: If `arg_value` is not an instance of `arg_type`, or is a bool.
        ValueError: If a non-integral `arg_value` is inf or NaN.
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'

    if not isinstance(arg_value, arg_type) or isinstance(arg_value, bool):
        raise TypeError(f"{prim_name} type of {arg_name} must be '{arg_type.__name__}', " \
                        f"but got '{type(arg_value).__name__}'.")
    # Ints cannot be inf/NaN, and math.isinf() raises OverflowError on ints too large for a float,
    # so only non-integral values are screened.
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)
                                           or np.isinf(arg_value) or np.isnan(arg_value)):
        raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.")
    return arg_value
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
    """
    Method for checking whether a number lies between `lower_limit` and `upper_limit`.

    Args:
        arg_value: Value to check. Instances of `value_type`, ``np.ndarray`` and ``np.generic`` are
            accepted; ``bool`` is rejected.
        lower_limit: Lower bound of the interval.
        upper_limit: Upper bound of the interval.
        rel: Endpoint inclusion code: INC_NEITHER "()", INC_LEFT "[)", INC_RIGHT "(]", INC_BOTH "[]".
        value_type: Expected Python type, e.g. ``int`` or ``float``.
        arg_name (str): Name used in error messages. Default: None.
        prim_name (str): Operator name used in error messages. Default: None.

    Returns:
        `arg_value`, unchanged.

    Raises:
        TypeError: If `arg_value` has an unaccepted type or is a bool.
        ValueError: If `arg_value` lies outside the interval.

    Usage:
        - number = check_number_range(number, 0.0, 1.0, INC_NEITHER, float, "number")  # number in (0.0, 1.0)
        - number = check_number_range(number, 0, 1, INC_BOTH, int, "number")  # number in [0, 1]
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'

    type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
    if type_mismatch:
        raise TypeError(f"{prim_name} {arg_name} must be '{value_type.__name__}', " \
                        f"but got '{type(arg_value).__name__}'.")

    if not _check_inc_rel(arg_value, lower_limit, upper_limit, rel):
        rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
        raise ValueError(f"{prim_name} {arg_name} must be in range of {rel_str}, " \
                         f"but got {arg_value} with type '{type(arg_value).__name__}'.")

    return arg_value
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_cls=ValueError):
    """
    Check that `arg_value` satisfies relation `rel` against `value` and return `arg_value`.

    Intended for ints or lists/tuples of ints. It is not suitable for floats, because exact comparison
    ignores rounding error. On failure, `excp_cls` is raised with a message naming both `arg_name` and
    `value_name`. An `arg_name` containing a space is treated as a phrase and is not quoted.
    """
    if _check_binary_rel(arg_value, value, rel):
        return arg_value

    rel_str = _format_str_one_value(f'{value_name}: {value}', rel)
    prefix = f'For \'{prim_name}\', the' if prim_name else "The"
    subject = f"{prefix} {arg_name}" if " " in arg_name else f"{prefix} \'{arg_name}\'"
    raise excp_cls(f'{subject} should be {rel_str}, but got {arg_value}.')
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an int (not a bool) satisfying relation `rel` against `value`.

    Usage:
        - number = check_int(number, 0, GE, "number", None)  # number >= 0
    """
    return _check_number(arg_value, value, rel, arg_type=int, arg_name=arg_name, prim_name=prim_name)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def check_is_int(arg_value, arg_name=None, prim_name=None):
    """
    Checks that the input value is of int type (bools are rejected).

    Usage:
        - number = check_is_int(number)
        - number = check_is_int(number, "bias")
        - number = check_is_int(number, "bias", "bias_class")

    Returns:
        `arg_value`, unchanged.
    """
    return check_is_number(arg_value, int, arg_name, prim_name)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an int (not a bool) equal to `value`.

    Usage:
        - number = check_equal_int(number, 0, "number", None)  # number == 0
    """
    return _check_number(arg_value, value, EQ, arg_type=int, arg_name=arg_name, prim_name=prim_name)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def check_positive_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a positive int (arg_value > 0, bools rejected).

    Usage:
        - number = check_positive_int(number)
        - number = check_positive_int(number, "bias")
    """
    return _check_number(arg_value, 0, GT, arg_type=int, arg_name=arg_name, prim_name=prim_name)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def check_positive_int_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a positive int (element > 0, bools rejected).

    Error messages name the offending element as ``<arg_name>[<index>]``. When `arg_name` is not
    given, the literal placeholder ``arg_name`` is used.

    Usage:
        - sequence = check_positive_int_sequence(sequence)
        - sequence = check_positive_int_sequence(sequence, "dims")

    Returns:
        `sequence`, unchanged.
    """
    name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GT, int, f"{name}[{idx}]", prim_name)
    return sequence
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def check_negative_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a negative int (arg_value < 0, bools rejected).

    Usage:
        - number = check_negative_int(number)
        - number = check_negative_int(number, "bias")
    """
    return _check_number(arg_value, 0, LT, arg_type=int, arg_name=arg_name, prim_name=prim_name)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that the argument is a non-positive integer, i.e. arg_value <= 0 (bools rejected).

    Usage:
        - number = check_non_positive_int(number)
        - number = check_non_positive_int(number, "bias")

    Returns:
        `arg_value`, unchanged.
    """
    return _check_number(arg_value, 0, LE, int, arg_name, prim_name)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a non-negative int (arg_value >= 0, bools rejected).

    Usage:
        - number = check_non_negative_int(number)
        - number = check_non_negative_int(number, "bias")
    """
    return _check_number(arg_value, 0, GE, arg_type=int, arg_name=arg_name, prim_name=prim_name)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a non-negative int (element >= 0, bools rejected).

    Error messages name the offending element as ``<arg_name>[<index>]``. When `arg_name` is not
    given, the literal placeholder ``arg_name`` is used.

    Usage:
        - sequence = check_non_negative_int_sequence(sequence)
        - sequence = check_non_negative_int_sequence(sequence, "dims")

    Returns:
        `sequence`, unchanged.
    """
    name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GE, int, f"{name}[{idx}]", prim_name)
    return sequence
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a finite float satisfying relation `rel` against `value`.

    Usage:
        - number = check_float(number, 0.0, GE, "number", None)  # number >= 0
    """
    return _check_number(arg_value, value, rel, arg_type=float, arg_name=arg_name, prim_name=prim_name)
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def check_is_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a float that is neither inf nor NaN.

    Usage:
        - number = check_is_float(number)
        - number = check_is_float(number, "bias")
        - number = check_is_float(number, "bias", "bias_class")
    """
    return check_is_number(arg_value, float, arg_name=arg_name, prim_name=prim_name)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def check_positive_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a positive float (arg_value > 0).

    Usage:
        - number = check_positive_float(number)
        - number = check_positive_float(number, "bias")
        - number = check_positive_float(number, "bias", "bias_class")
    """
    return _check_number(arg_value, 0, GT, arg_type=float, arg_name=arg_name, prim_name=prim_name)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def check_positive_float_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a positive float (element > 0).

    Error messages name the offending element as ``<arg_name>[<index>]``. When `arg_name` is not
    given, the literal placeholder ``arg_name`` is used.

    Usage:
        - sequence = check_positive_float_sequence(sequence)
        - sequence = check_positive_float_sequence(sequence, "dims")

    Returns:
        `sequence`, unchanged.
    """
    name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GT, float, f"{name}[{idx}]", prim_name)
    return sequence
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def check_negative_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a negative float (arg_value < 0).

    Usage:
        - number = check_negative_float(number)
        - number = check_negative_float(number, "bias")
    """
    return _check_number(arg_value, 0, LT, arg_type=float, arg_name=arg_name, prim_name=prim_name)
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that the argument is a non-positive float, i.e. arg_value <= 0.

    Usage:
        - number = check_non_positive_float(number)
        - number = check_non_positive_float(number, "bias")

    Returns:
        `arg_value`, unchanged.
    """
    return _check_number(arg_value, 0, LE, float, arg_name, prim_name)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a non-negative float (arg_value >= 0).

    Usage:
        - number = check_non_negative_float(number)
        - number = check_non_negative_float(number, "bias")
    """
    return _check_number(arg_value, 0, GE, arg_type=float, arg_name=arg_name, prim_name=prim_name)
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def check_number(arg_name, arg_value, value, rel, prim_name):
    """
    Check that `arg_value` satisfies relation `rel` against `value`.

    Returns `arg_value` on success. Otherwise raises ValueError naming `prim_name` and `arg_name`.
    No type check is performed.
    """
    if _check_binary_rel(arg_value, value, rel):
        return arg_value
    rel_str = _format_str_one_value(value, rel)
    raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' ' \
                     f'must {rel_str}, but got {arg_value}.')
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def check_isinstance(arg_name, arg_value, classes):
    """
    Return `arg_value` if it is an instance of `classes`.

    Raises:
        ValueError: If `arg_value` is not an instance of `classes`. The exception type is ValueError
            rather than TypeError, so callers may rely on it.
    """
    if isinstance(arg_value, classes):
        return arg_value
    raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.')
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def check_bool(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a bool and return it.

    Usage:
        - has_bias = check_bool(has_bias)
        - has_bias = check_bool(has_bias, "has_bias")

    Raises:
        TypeError: If `arg_value` is not a bool (ints such as 0/1 are rejected).
    """
    if isinstance(arg_value, bool):
        return arg_value
    subject = f"For '{prim_name}', the" if prim_name else 'The'
    name = f"'{arg_name}'" if arg_name else 'input value'
    raise TypeError(f"{subject} {name} must be a bool, but got {type(arg_value).__name__}.")
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
    """
    Method for checking whether an int value lies between `lower_limit` and `upper_limit`.

    `rel` selects endpoint inclusion: INC_NEITHER "()", INC_LEFT "[)", INC_RIGHT "(]", INC_BOTH "[]".

    Usage:
        - number = check_int_range(number, 0, 1, INC_NEITHER)  # number in (0, 1)
        - number = check_int_range(number, 0, 1, INC_BOTH, "number")  # number in [0, 1]

    Returns:
        `arg_value`, unchanged.
    """
    return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
    """
    Method for checking whether a float value lies between `lower_limit` and `upper_limit`.

    `rel` selects endpoint inclusion: INC_NEITHER "()", INC_LEFT "[)", INC_RIGHT "(]", INC_BOTH "[]".

    Usage:
        - number = check_float_range(number, 0.0, 1.0, INC_NEITHER)  # number in (0.0, 1.0)
        - number = check_float_range(number, 0.0, 1.0, INC_BOTH, "number")  # number in [0.0, 1.0]

    Returns:
        `arg_value`, unchanged.
    """
    return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a str contained in `valid_values`, and return it.

    Usage:
        - method = check_string(method, ["string1", "string2", "string3"], "method")

    Raises:
        ValueError: If `arg_value` is not a str or is not in `valid_values`.
    """
    if isinstance(arg_value, str) and arg_value in valid_values:
        return arg_value
    name = arg_name or "parameter"
    prefix = f'For \'{prim_name}\', the' if prim_name else "The"
    raise ValueError(f"{prefix} '{name}' must be str and must be in '{valid_values}'," \
                     f" but got '{arg_value}'.")
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
    """
    Check that `target` matches the regular expression `reg` (via `re.match`).

    Args:
        target (str): String to validate.
        reg (str): Pattern to match. When None, a default name pattern is used: one or
            more word characters followed by letters, digits, '_' or '.'.
        flag: Flags passed to `re.match`; `re.ASCII` by default.
        prim_name (str): Optional primitive name prefixed to the error message.

    Returns:
        True when `target` matches.

    Raises:
        ValueError: If `re.match` finds no match.
    """
    # Named string regular expression
    pattern = r"^\w+[0-9a-zA-Z\_\.]*$" if reg is None else reg
    if re.match(pattern, target, flag) is not None:
        return True
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular'{pattern}' by flags'{flag}.'")
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def check_str_and_none_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
    """
    Check that `target` matches the regular expression `reg` (via `re.match`).

    Unlike `check_str_by_regular`, the default pattern uses `*` for both parts, so it
    also accepts the empty string, and it additionally allows '-'.

    Args:
        target (str): String to validate.
        reg (str): Pattern to match; the permissive default is used when None.
        flag: Flags passed to `re.match`; `re.ASCII` by default.
        prim_name (str): Optional primitive name prefixed to the error message.

    Returns:
        True when `target` matches.

    Raises:
        ValueError: If `re.match` finds no match.
    """
    # Named string regular expression
    pattern = r"^\w*[0-9a-zA-Z\_\.\-]*$" if reg is None else reg
    if re.match(pattern, target, flag) is not None:
        return True
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular'{pattern}' by flags'{flag}.'")
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def check_file_name_by_regular(target, reg=None, prim_name=None):
    """
    Check whether a file name is legitimate.

    Args:
        target (str): File name or path to validate.
        reg (str): Pattern checked with `re.match`. By default only letters, digits and
            the characters @ _ - . : / \\ are allowed.
        prim_name (str): Optional primitive name prefixed to error messages.

    Returns:
        True when `target` is acceptable.

    Raises:
        TypeError: If `target` is not a str.
        ValueError: If `target` ends with a path separator or does not match `reg`.
    """
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    if not isinstance(target, str):
        raise TypeError(f"{prefix} '{target}' must be string, but got {type(target)}.")
    # A trailing separator means a directory, not a file.
    if target.endswith(("\\", "/")):
        raise ValueError(f"{prefix} '{target}' cannot be a directory path.")
    pattern = r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$" if reg is None else reg
    if re.match(pattern, target) is None:
        raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular '{pattern}'.")
    return True
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def check_pad_value_by_mode(pad_mode, padding, prim_name):
    """
    Validate `padding` against `pad_mode`.

    Any padding is allowed when `pad_mode` is 'pad'; for every other mode it must be 0.

    Returns:
        `padding`, unchanged, when valid.

    Raises:
        ValueError: If `pad_mode` is not 'pad' and `padding` is non-zero.
    """
    if pad_mode == 'pad' or padding == 0:
        return padding
    raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}',"
                     f" but got {padding}.")
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
    """
    Check that `type_` matches at least one of `template_types`.

    A template that is an `mstype.Type` matches when `mstype._issubclass_(type_, template)`
    is truthy; any other template matches only if it is the very same object as `type_`.

    Args:
        arg_name (str): Argument name used in the error message.
        type_: The type being checked.
        template_types: A single template or an iterable of templates.
        prim_name (str): Primitive name used in the error message.
        addition_error_info (str): Optional extra text inserted into the error message.

    Raises:
        TypeError: If no template matches.
    """
    if not isinstance(template_types, Iterable):
        template_types = (template_types,)

    def _is_match(template_type):
        if isinstance(template_type, mstype.Type):
            return mstype._issubclass_(type_, template_type)  # pylint: disable=W0212
        return type_ is template_type

    if any(_is_match(template_type) for template_type in template_types):
        return

    extra_info = '' if addition_error_info is None else ' ' + addition_error_info
    # Lists/tuples are reported by their Python type name rather than their contents.
    type_str = f"type '{type(type_).__name__}'" if isinstance(type_, (tuple, list)) else str(type_)
    raise TypeError(f"For '{prim_name}', the element of '{arg_name}'"
                    f" must be {'one of ' if len(template_types) > 1 else ''}"
                    f"{', '.join((str(x) for x in template_types))}, but got {type_str}"
                    f"{extra_info}.The supported data types depend on the hardware that"
                    f" executes the operator, for more details, please refer to the MindSpore official "
                    f"website to get more information about the data type.")
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def check_valid_input(arg_name, arg_value, prim_name):
    """
    Check that `arg_value` is not None.

    Args:
        arg_name (str): Argument name used in the error message.
        arg_value: Value to check.
        prim_name (str): Primitive name used in the error message.

    Returns:
        `arg_value`, unchanged, when it is not None.

    Raises:
        ValueError: If `arg_value` is None.
    """
    if arg_value is None:
        raise ValueError(f"For '{prim_name}', the argument '{arg_name}' "
                         f"can not be None, but got {arg_value}.")
    return arg_value
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def check_types_same_and_valid(args, valid_values, prim_name):
    """
    Check that every type in `args` is valid and that all of them are the same.

    Args:
        args (dict): Maps argument name to its type.
        valid_values: Template type(s) accepted by `check_subclass`.
        prim_name (str): Primitive name used in error messages.

    Raises:
        TypeError: If a type is invalid (raised by `check_subclass`) or if two
            arguments have different types.
    """
    if not args:
        # Nothing to compare. Without this guard `reduce` raises an unrelated
        # "reduce() of empty iterable with no initial value" TypeError.
        return

    def _check_type_valid(arg):
        arg_key, arg_val = arg
        check_subclass(arg_key, arg_val, valid_values, prim_name)
        return arg

    def _check_types_same(arg1, arg2):
        arg1_name, arg1_type = arg1
        arg2_name, arg2_type = arg2
        if arg1_type != arg2_type:
            raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' should be same as '{arg1_name}',"
                            f" but got '{arg1_name}' with type {arg1_type}"
                            f" and '{arg2_name}' with type {arg2_type}.")
        # The first argument stays the reference for all later comparisons.
        return arg1

    # `map` is lazy, so validity and sameness checks interleave argument by argument.
    reduce(_check_types_same, map(_check_type_valid, args.items()))
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
    """
    Check that the element dtypes of the tensors in `args` are valid and identical.

    Each dtype in `valid_dtypes` (a single dtype is accepted too) is wrapped in
    `mstype.TensorType` before delegating to `check_types_same_and_valid`.
    """
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    allowed_tensor_types = [mstype.TensorType(dtype) for dtype in valid_dtypes]
    check_types_same_and_valid(args, allowed_tensor_types, prim_name)
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
    """
    Check that the tensor type `arg_type` has one of the element dtypes in `valid_dtypes`.

    Each dtype in `valid_dtypes` (a single dtype is accepted too) is wrapped in
    `mstype.TensorType` before delegating to `check_subclass`.
    """
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    allowed_tensor_types = [mstype.TensorType(dtype) for dtype in valid_dtypes]
    check_subclass(arg_name, arg_type, allowed_tensor_types, prim_name)
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
    """
    Check that the types of the inputs are valid and the same.

    Tensor types are compared by their element type. If `allow_mix` is True,
    Tensor(float32) and float32 are type compatible; otherwise mixing a tensor
    with a scalar raises.

    Args:
        args (dict): Maps argument name to its (scalar or tensor) type.
        valid_values: Container of accepted (element) types.
        prim_name (str): Primitive name used in error messages.
        allow_mix (bool): Whether a tensor and a scalar of the same dtype match.

    Raises:
        TypeError: If a type is not in `valid_values` or the types differ.
    """
    if not args:
        # Nothing to compare. Without this guard `reduce` raises an unrelated
        # "reduce() of empty iterable with no initial value" TypeError.
        return
    tensor_type_cls = type(mstype.tensor_type)

    def _check_argument_type(arg):
        arg_key, arg_val = arg
        if isinstance(arg_val, tensor_type_cls):
            arg_val = arg_val.element_type()
        if arg_val not in valid_values:
            raise TypeError(f'For \'{prim_name}\', the type of \'{arg_key}\' must be in {valid_values},'
                            f' but got {arg_val}.')
        return arg

    def _raise_mismatch(arg1_name, arg1_type, arg2_name, arg2_type):
        raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' must be same as '{arg1_name}',"
                        f" but got '{arg1_name}' with type {arg1_type}"
                        f" and '{arg2_name}' with type {arg2_type}.")

    def _check_types_same(arg1, arg2):
        arg1_name, arg1_type = arg1
        arg2_name, arg2_type = arg2
        is_tensor1 = isinstance(arg1_type, tensor_type_cls)
        is_tensor2 = isinstance(arg2_type, tensor_type_cls)
        if is_tensor1 != is_tensor2 and not allow_mix:
            # Tensor vs. scalar is only compatible when allow_mix is set; report the raw types.
            _raise_mismatch(arg1_name, arg1_type, arg2_name, arg2_type)
        if is_tensor1:
            arg1_type = arg1_type.element_type()
        if is_tensor2:
            arg2_type = arg2_type.element_type()
        if arg1_type != arg2_type:
            _raise_mismatch(arg1_name, arg1_type, arg2_name, arg2_type)
        # The first argument stays the reference for all later comparisons.
        return arg1

    # `map` is lazy, so validity and sameness checks interleave argument by argument.
    reduce(_check_types_same, map(_check_argument_type, args.items()))
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    """
    Check that `arg_value` is an instance of one of `valid_types`.

    Args:
        arg_name (str): Argument name used in the error message.
        arg_value: Value to check.
        valid_types: A type or an iterable of types.
        prim_name (str): Optional primitive name prefixed to the error message.

    Returns:
        `arg_value`. A float is first rounded to 6 digits when `float` itself is not
        listed in `valid_types`.

    Raises:
        TypeError: If the value's type is not accepted.
    """
    if not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)

    def _fail(value):
        names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
        count = len(valid_types)
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f'{prefix} type of \'{arg_name}\' should be {"one of " if count > 1 else ""}'
                        f'\'{names if count > 1 else names[0]}\', '
                        f'but got type \'{type(value).__name__}\'.')

    # bool is a subclass of int, so a bool is only accepted when bool itself is listed:
    # check_value_type('x', True, [int]) fails, check_value_type('x', True, [bool, int]) passes.
    if isinstance(arg_value, bool) and bool not in tuple(valid_types):
        _fail(arg_value)
    if isinstance(arg_value, float) and float not in tuple(valid_types):
        arg_value = round(arg_value, 6)
    if not isinstance(arg_value, tuple(valid_types)):
        _fail(arg_value)
    return arg_value
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def check_type_name(arg_name, arg_type, valid_types, prim_name):
    """
    Check that `arg_type` is one of `valid_types`.

    Tensor types are unwrapped to their element type before the membership test.

    Returns:
        The (possibly unwrapped) type when it is valid.

    Raises:
        TypeError: If the type is not in `valid_types`.
    """
    if not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)
    if isinstance(arg_type, type(mstype.tensor_type)):
        arg_type = arg_type.element_type()
    if arg_type in valid_types:
        return arg_type

    type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
    num_types = len(valid_types)
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
                    f"{type_names if num_types > 1 else type_names[0]}, "
                    f"but got '{arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}'.")
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
    """
    Check that `shape` equals `ori_shape` with the dimensions in `axis` removed.

    Args:
        ori_shape (tuple): Shape before the reduction.
        shape (tuple): Shape to validate.
        axis (Union[int, Iterable[int]]): Reduced axis/axes; negative values count
            from the end (e.g. -1 is the last dimension).
        prim_name (str): Primitive name used in the error message.
        arg_name1 (str): Name of the argument whose shape is reduced.
        arg_name2 (str): Name of the argument holding `shape`.

    Raises:
        ValueError: If `shape` is not the reduced `ori_shape`.
    """
    axis_origin = axis
    axes = axis if isinstance(axis, Iterable) else (axis,)
    ndim = len(ori_shape)
    # Normalize negative int axes; previously axis=-1 excluded nothing and valid shapes were rejected.
    reduced = {ax + ndim if isinstance(ax, int) and ax < 0 else ax for ax in axes}
    exp_shape = [dim for i, dim in enumerate(ori_shape) if i not in reduced]
    if list(shape) != exp_shape:
        raise ValueError(f"For '{prim_name}', "
                         f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis_origin} must "
                         f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
def check_astype_dtype(dtype):
    """
    Validate the `dtype` argument of Tensor.astype and convert it to an mstype.

    Accepts a dtype name (case-insensitive), a Python type, or an mstype number/bool type.

    Raises:
        TypeError: If `dtype` is not a supported name or mstype.
    """
    all_types = mstype.__dtype__ + ["int", "float", "bool"]
    if isinstance(dtype, str):
        lowered = dtype.lower()
        if lowered not in all_types:
            raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
        return mstype.pytype_to_dtype(np.dtype(lowered))
    if isinstance(dtype, type):
        return mstype.pytype_to_dtype(dtype)
    allowed = mstype.number_type + (mstype.bool_,)
    if dtype not in allowed:
        raise TypeError(f"For Tensor.astype, the input type must be one of {allowed},"
                        f" but got '{dtype}'.")
    return dtype
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def check_transpose_axis(axes, ndim):
    """
    Normalize the `*axes` arguments given to Tensor.transpose.

    - No axes, or a single None: returns the reversed dimension order.
    - A single list/tuple: returned as a tuple.
    - A single int: only valid when `ndim` is 1; returned as a 1-tuple.
    - Several ints: there must be exactly `ndim` of them; returned unchanged.

    Raises:
        ValueError: If the number of separately passed axes differs from `ndim`.
        TypeError: If a single argument is not None, int, list or tuple.
    """
    def _check_dim():
        if len(axes) != ndim:
            raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, "
                             f"but got {len(axes)} in the number of axes.")

    if not axes or (len(axes) == 1 and axes[0] is None):
        return tuple(range(ndim - 1, -1, -1))

    if len(axes) != 1:
        _check_dim()
        return axes

    perm = axes[0]
    if isinstance(perm, list):
        return tuple(perm)
    if isinstance(perm, int):
        _check_dim()
        return (perm,)
    if not isinstance(perm, tuple):
        raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, "
                        f"or series of integer, but got {type(axes[0])}")
    return perm
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
def check_reshape_shp(shp):
    """
    Normalize the `*shape` arguments given to Tensor.reshape.

    A single list/tuple argument is unwrapped to a tuple. A single int, or several
    arguments, are returned as the original `shp` tuple.

    Raises:
        TypeError: If a single argument is not an int, list or tuple.
    """
    if len(shp) != 1:
        return shp
    new_shape = shp[0]
    if isinstance(new_shape, int):
        return shp
    if isinstance(new_shape, list):
        return tuple(new_shape)
    if isinstance(new_shape, tuple):
        return new_shape
    raise TypeError(
        f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, " \
        f"or series of integer, but got {type(shp[0])}")
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
def check_flatten_order(order):
    """
    Validate the `order` argument of Tensor.flatten.

    Raises:
        TypeError: If `order` is not a str.
        ValueError: If `order` is neither 'C' nor 'F'.
    """
    if isinstance(order, str):
        if order in ('C', 'F'):
            return
        raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
    raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
def check_swapaxes_axis(axes, ndim):
    """
    Validate the axis argument(s) of ops.swapaxes and map them to non-negative values.

    Returns:
        For an int, the normalized axis from `check_axis_in_range`; for a list/tuple,
        a tuple of the normalized axes.

    Raises:
        TypeError: If `axes` (or one of its items) has an unsupported type.
        ValueError: Propagated from `check_axis_in_range` for out-of-range axes.
    """
    if isinstance(axes, int):
        return check_axis_in_range(axes, ndim)
    if not isinstance(axes, (tuple, list)):
        raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, " \
                        f"but got {type(axes)}.")
    for axis in axes:
        if not isinstance(axis, int):
            raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
        check_axis_in_range(axis, ndim)
    return tuple((axis + ndim) % ndim for axis in axes)
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
def prepare_shape_for_squeeze(shape, axes):
    """
    Creates the squeezed new shape based on the tensor and given axes.

    Args:
        shape (tuple): the shape of the tensor
        axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
            be squeezed; negative values count from the end, duplicates are ignored.

    Returns:
        new_shape(tuple): the shape with the selected size-1 dimensions removed.

    Raises:
        ValueError: If an axis is out of range or a selected dimension is not 1.
        TypeError: If `axes` is not an int, list or tuple.
    """
    ndim = len(shape)

    def _check(axis):
        if axis >= ndim or axis < -ndim:
            raise ValueError(f"For Tensor.squeeze, the 'axis' must be in the range of [-{ndim}, {ndim}), " \
                             f"but got {axis}.")

    if isinstance(axes, int):
        _check(axes)
        axes = (axes,)
    elif isinstance(axes, (list, tuple)):
        for axis in axes:
            _check(axis)
        # Drop duplicates while keeping first-seen order.
        axes = tuple(dict.fromkeys(axes))
    else:
        raise TypeError(f"For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], " \
                        f"but got {type(axes)}")

    kept = []
    for idx, size in enumerate(shape):
        selected = idx in axes or idx - ndim in axes
        # if an axis is selected with shape entry greater than one, an error is raised.
        if size != 1 and selected:
            raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {axes} must be 1, but got {size}.")
        if size != 1 or not selected:
            kept.append(size)
    return tuple(kept)
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
def check_axis_in_range(axis, ndim):
    """
    Check that `axis` is an int in [-ndim, ndim) and return it mapped to [0, ndim).

    Raises:
        TypeError: If `axis` is not an int.
        ValueError: If `axis` is out of range.
    """
    if not isinstance(axis, int):
        raise TypeError(f'The axes must be integers, but got {type(axis)}')
    if not -ndim <= axis < ndim:
        raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
    return (axis + ndim) % ndim
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def check_axis_valid(axes, ndim):
    """
    Checks axes are valid given ndim, and returns axes that can be passed
    to the built-in operator (non-negative, int or tuple).

    None selects every dimension. A list/tuple is range-checked, normalized and
    checked for duplicates via `_check_dup`; a single axis becomes a 1-tuple.
    """
    if axes is None:
        return tuple(range(ndim))
    if isinstance(axes, (tuple, list)):
        for axis in axes:
            check_axis_in_range(axis, ndim)
        normalized = tuple((axis + ndim) % ndim for axis in axes)
        _check_dup(normalized)
        return normalized
    check_axis_in_range(axes, ndim)
    return (axes % ndim,)
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
def max_(*args):
    """Return the largest of the given arguments (or of a single iterable), like built-in `max`."""
    values = args
    return max(*values)
|
|
965
|
+
|
|
966
|
+
|
|
967
|
+
def min_(*args):
    """Return the smallest of the given arguments (or of a single iterable), like built-in `min`."""
    values = args
    return min(*values)
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
def is_stub_tensor(tensor):
    """Return True if `tensor` has a `stub` attribute (as reported by `hasattr`)."""
    has_stub = hasattr(tensor, "stub")
    return has_stub
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
def expanded_shape(ndim, axis_size, axis):
    """
    Return an `ndim`-long shape that is 1 everywhere except at position `axis`,
    which holds `axis_size`.

    An `axis` that equals no index in range(ndim) (e.g. a negative one) yields all ones.
    """
    dims = []
    for dim in range(ndim):
        dims.append(axis_size if dim == axis else 1)
    return tuple(dims)
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
def tuple_slice(tup, start, end):
    """Return `tup[start:end]` using standard slice semantics."""
    window = slice(start, end)
    return tup[window]
|
|
987
|
+
|
|
988
|
+
|
|
989
|
+
def infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.

    Shapes are right-aligned; missing leading dimensions count as 1. On each axis
    every size must be 1 or the axis maximum. A 0 on an axis makes the result 0
    there, and every other size on that axis must then be 0 or 1. With no shapes
    the result is the scalar shape ().

    Raises:
        ValueError: If the shapes cannot be broadcast together.
    """
    if not shapes:
        # Previously max() on an empty sequence raised an unrelated ValueError.
        return ()
    max_len = max(len(it) for it in shapes)
    shape_out = ()
    for i in range(max_len):
        items = []
        for it in shapes:
            offset = i - (max_len - len(it))
            items.append(it[offset] if offset >= 0 else 1)
        max_size = 0 if 0 in items else max(items)
        for item in items:
            if item not in (1, max_size):
                raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis '
                                 f'to support broadcasting, but got shapes {shapes}')
        shape_out += (max_size,)
    return shape_out
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
    """
    Check the type of an `axis` argument.

    Args:
        axis: Value to check.
        type_int (bool): Whether an int is accepted.
        type_tuple (bool): Whether a tuple of ints is accepted.
        type_list (bool): Whether a list of ints is accepted.

    Returns:
        True when `axis` has an accepted type.

    Raises:
        TypeError: If `axis` (or an element of a tuple/list `axis`) has a wrong type.
    """
    if type_int and isinstance(axis, int):
        return True
    if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
        for ax in axis:
            if not isinstance(ax, int):
                raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.")
        return True

    # Join the enabled names so the message has no dangling ", " before "but got".
    allowed = [name for name, enabled in (("int", type_int), ("tuple", type_tuple), ("list", type_list))
               if enabled]
    type_str = ", ".join(allowed)
    raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.")
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
def check_and_canonicalize_axes(axes, ndim):
    """
    Check that each axis is an int in [-ndim, ndim) and map it to [0, ndim).

    A non-tuple `axes` is treated as a single axis. Duplicates are rejected by `_check_dup`.

    Returns:
        tuple: The canonical (non-negative) axes.

    Raises:
        TypeError: If an axis is not an int.
        ValueError: If an axis is out of range.
    """
    if not isinstance(axes, tuple):
        axes = (axes,)
    canonical = []
    for ax in axes:
        if not isinstance(ax, int):
            raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
        if ax >= ndim or ax < -ndim:
            raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.")
        canonical.append(ax + ndim if ax < 0 else ax)
    new_axes = tuple(canonical)
    _check_dup(new_axes)
    return new_axes
|
|
1047
|
+
|
|
1048
|
+
|
|
1049
|
+
def check_type_support(dtype, device, supported_dtypes):
    """
    Return True if `dtype` is in `supported_dtypes`, or if the current
    `device_target` context is not `device` (the restriction only applies there).
    """
    if dtype in supported_dtypes:
        return True
    return context.get_context('device_target') != device
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
def check_sparse_tensor_input(indices, values, shape):
    """
    Common input-type check for SparseTensors.

    `indices` and `values` must be `Tensor_` instances or stub tensors; `shape` must be a tuple.

    Raises:
        TypeError: If any input has the wrong type.
    """
    for name, tensor in (("indices", indices), ("values", values)):
        if not isinstance(tensor, Tensor_) and not is_stub_tensor(tensor):
            raise TypeError(f"For SparseTensors, '{name}' must be Tensor, but got {type(tensor)}.")
    if not isinstance(shape, tuple):
        raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
def check_csr_tensor_input(indptr, indices, values, shape):
    """
    Check the input types for CSRTensor.

    `indptr` must be a `Tensor_` or stub tensor; the remaining inputs are checked by
    `check_sparse_tensor_input`.

    Raises:
        TypeError: If any input has the wrong type.
    """
    indptr_ok = isinstance(indptr, Tensor_) or is_stub_tensor(indptr)
    if not indptr_ok:
        raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
    check_sparse_tensor_input(indices, values, shape)
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
    """
    Check that the shapes of a CSRTensor's components are mutually consistent.

    Args:
        indptr_shp (tuple): Shape of `indptr`; must be 1-D with length csr_shp[0] + 1.
        indices_shp (tuple): Shape of `indices`; must be 1-D.
        values_shp (tuple): Shape of `values`; values_shp[0] must equal indices_shp[0],
            and values_shp[1:] must equal csr_shp[2:].
        csr_shp (tuple): Dense shape of the CSRTensor; entries must be positive ints.

    Raises:
        TypeError: If an entry of `csr_shp` is not an int.
        ValueError: If any of the shape constraints is violated.
    """
    # Support empty sparse tensor
    if (indptr_shp == (0,)) and (indices_shp == (0,)) and (values_shp == (0,)):
        return
    shape_size = 1
    for item in csr_shp:
        # Check the type first so a non-int entry gets the explicit TypeError
        # instead of failing the `<= 0` comparison.
        if not isinstance(item, int):
            raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}")
        if item <= 0:
            raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}")
        shape_size *= item
    val_shp_size = 1
    for item in values_shp:
        if item <= 0:
            raise ValueError(f"The element of shape must be positive, but got {item}")
        val_shp_size *= item
    if shape_size < val_shp_size:
        raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.")
    if len(indices_shp) != 1:
        raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, "
                         f"but got a {len(indices_shp)} dimension tensor.")
    if len(indptr_shp) != 1:
        raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, "
                         f"but got a {len(indptr_shp)} dimension tensor.")
    if csr_shp[0] + 1 != indptr_shp[0]:
        raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), "
                         f"but got: {indptr_shp[0]}")
    if indices_shp[0] != values_shp[0]:
        err_msg1 = "For CSRTensor, indices and values must equal in their shape, "
        err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
        raise ValueError(err_msg1 + err_msg2)
    if len(values_shp) + 1 != len(csr_shp):
        raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got "
                         f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "
                         f"{len(csr_shp)}")
    if values_shp[1:] != csr_shp[2:]:
        raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ], "
                         f"but CSRTensor's shape[2: ] got: {csr_shp[2:]} and value's shape[1: ] "
                         f"got: {values_shp[1:]}")
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
    """
    Check that CSRTensor's `indptr` and `indices` use int16, int32 or int64.

    Raises:
        TypeError: If either dtype is not one of these.
    """
    allowed = (mstype.int16, mstype.int32, mstype.int64)
    for name, dtype in (("indptr", indptr_dtype), ("indices", indices_dtype)):
        if dtype not in allowed:
            raise TypeError(f"For CSRTensor, {name} must have int16 or int32 or int64 data type, "
                            f"but got {dtype}.")
|
|
1121
|
+
|
|
1122
|
+
|
|
1123
|
+
def check_coo_tensor_input(indices, values, shape):
    """Check the input types for COOTensor (same rules as `check_sparse_tensor_input`)."""
    return check_sparse_tensor_input(indices, values, shape)
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
    """
    Check that the shapes of a COOTensor's components are mutually consistent.

    Args:
        indices_shp (tuple): Shape of `indices`; expected (N, 2).
        values_shp (tuple): Shape of `values`; expected (N,).
        coo_shp (tuple): Dense shape of the COOTensor; must have 2 positive int entries.

    Raises:
        TypeError: If an entry of `coo_shp` is not an int.
        ValueError: If any of the shape constraints is violated.
    """
    if len(coo_shp) != 2:
        raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.")
    # Empty sparse tensor: nothing else to check.
    if (indices_shp == (0,)) and (values_shp == (0,)):
        return
    shp_mul = 1
    for sh in coo_shp:
        # Check the type first so a non-int entry gets the explicit TypeError
        # instead of failing the `<= 0` comparison.
        if not isinstance(sh, int):
            raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
        if sh <= 0:
            raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.")
        shp_mul *= sh
    # Validate ranks before indexing, so malformed shapes raise ValueError rather than IndexError.
    if len(indices_shp) != 2:
        raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}"
                         f"-dimensional tensor.")
    if len(values_shp) != 1:
        raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}"
                         f"-dimensional tensor.")
    if shp_mul < values_shp[0]:
        raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).")
    if indices_shp[0] != values_shp[0]:
        raise ValueError(f"For COOTensor, 'indices.shape[0]' must be equal to 'values.shape[0]', but got "
                         f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
    if indices_shp[1] != 2:
        raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.")
|
|
1154
|
+
|
|
1155
|
+
|
|
1156
|
+
def check_coo_tensor_dtype(indices_dtype):
    """
    Check that COOTensor's `indices` dtype is int16, int32 or int64.

    Raises:
        TypeError: If it is not one of these.
    """
    allowed = (mstype.int16, mstype.int32, mstype.int64)
    if indices_dtype in allowed:
        return
    raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got "
                    f"{indices_dtype}.")
|
|
1161
|
+
|
|
1162
|
+
|
|
1163
|
+
def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None):
    """
    Check that every element of a list/tuple is an instance of one of `valid_types`.

    Dicts are not handled here; see `check_element_type_of_dict`.

    Raises:
        TypeError: If `arg_value` is not a list/tuple, or an element has a wrong type.
    """
    check_value_type(arg_name, arg_value, [list, tuple], prim_name)
    type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
    num_types = len(valid_types)
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    allowed = tuple(valid_types)
    for element in arg_value:
        if isinstance(element, allowed):
            continue
        raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
                        f"{type_names if num_types > 1 else type_names[0]}, "
                        f"but got '{element}' with type '{type(element).__name__}'.")
|
|
1174
|
+
|
|
1175
|
+
|
|
1176
|
+
def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None):
    """
    Check that every key of a dict is one of `key_types` and every value one of `value_types`.

    Raises:
        TypeError: If `arg_value` is not a dict, or a key/value has a wrong type.
    """
    check_value_type(arg_name, arg_value, [dict], prim_name)
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"

    def _check_all(elements, valid_types):
        type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
        num_types = len(valid_types)
        for element in elements:
            if not isinstance(element, tuple(valid_types)):
                raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
                                f"{type_names if num_types > 1 else type_names[0]}, "
                                f"but got '{element}' with type '{type(element).__name__}'.")

    # Keys are fully checked before any value, as in the original ordering.
    _check_all(arg_value.keys(), key_types)
    _check_all(arg_value.values(), value_types)
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None):
    """
    Check that `arg_value` is a tuple of `expect_size` elements, each an instance of `expect_element_type`.

    Errors come from `check_value_type`, `check_equal_int` and
    `check_element_type_of_iterable`.
    """
    check_value_type(arg_name, arg_value, [tuple], prim_name)
    check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name)
    # Pass the real argument name. The literal 'arg_name' was passed before, so
    # element-type errors named the wrong argument.
    check_element_type_of_iterable(arg_name, arg_value, [expect_element_type], prim_name)
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
def _check_symbol(dyn_input, net_input, index, symbolic_shape_data):
    """
    Validate one input's actual shape against the `Symbol` constraints of its dynamic counterpart.

    Every entry of ``dyn_input.symbolic_shape`` that is a dict (a `Symbol` converted to dict) may carry:
    ``id`` (all axes sharing an id must have the same value), ``min`` / ``max`` (inclusive bounds) and
    ``divisor`` / ``remainder`` (value must satisfy ``value >= divisor`` and
    ``value % divisor == remainder``; defaults 1 and 0). Entries that are not dicts are skipped.

    Args:
        dyn_input: Object exposing ``symbolic_shape``.
        net_input: Object exposing ``shape``, indexed by axis.
        index (int): Position of the input; shown 1-based in error messages.
        symbolic_shape_data (dict): State shared between calls. The id -> value map is stored under
            the ``"unique_id_value_map"`` key so that ids stay consistent across inputs.

    Raises:
        ValueError: If any constraint is violated.
    """
    shape = net_input.shape

    def _prefix(axis):
        # Common leading part of every error message.
        return f"The {axis + 1}th shape value of {index + 1}th actual input args"

    for axis, sym in enumerate(dyn_input.symbolic_shape):
        if not isinstance(sym, dict):
            continue

        id_values = symbolic_shape_data.setdefault("unique_id_value_map", {}) if "id" in sym else None
        value = shape[axis]

        # Axes that share an "id" must all take the same value.
        if id_values is not None:
            sym_id = sym["id"]
            if sym_id in id_values:
                previous = id_values[sym_id]
                if previous != value:
                    raise ValueError(
                        f"{_prefix(axis)} is a unique symbol, all values must be the same. "
                        f"The previous value is {previous}, but the current value is {value}. "
                        f"Actual shape: {shape}, axis: {axis}.")
            else:
                id_values[sym_id] = value

        # Inclusive [min, max] range.
        if "min" in sym and value < sym["min"]:
            raise ValueError(
                f"{_prefix(axis)} must be greater than or equal to the 'min' value '{sym['min']}' of `Symbol`, "
                f"but got '{value}'. Actual shape: {shape}, axis: {axis}.")
        if "max" in sym and value > sym["max"]:
            raise ValueError(
                f"{_prefix(axis)} must be less than or equal to the 'max' value '{sym['max']}' of `Symbol`, "
                f"but got '{value}'. Actual shape: {shape}, axis: {axis}.")

        # The value must have the form "divisor * N + remainder" with N >= 1.
        divisor = sym.get("divisor", 1)
        remainder = sym.get("remainder", 0)
        if value < divisor or value % divisor != remainder:
            raise ValueError(
                f"{_prefix(axis)} must be match the 'divisor'(d) and 'remainder'(r) of `Symbol`. "
                f"The value should be 'd * N + r' for 'N > 0', got d={divisor} and r={remainder}, but "
                f"actual shape value is '{value}'. Actual shape: {shape}, axis: {axis}")
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
def check_symbolic_shape(dynamic_inputs, actual_inputs):
    """
    Check actual inputs against the symbolic shapes declared on the dynamic inputs.

    Both sequences are walked in lockstep. Nested tuples/lists in ``dynamic_inputs`` are descended
    into recursively, and every item that has a ``symbolic_shape`` attribute is validated with
    ``_check_symbol``. One state dict is shared by the whole walk, so symbol ids stay consistent
    across all inputs. The index reported to ``_check_symbol`` is the position within the
    innermost container being walked.
    """
    shared_state = {}

    def _walk(dyn_items, net_items):
        for position, (dyn_item, net_item) in enumerate(zip(dyn_items, net_items)):
            if isinstance(dyn_item, (tuple, list)):
                _walk(dyn_item, net_item)
                continue
            if hasattr(dyn_item, "symbolic_shape"):
                _check_symbol(dyn_item, net_item, position, shared_state)

    _walk(dynamic_inputs, actual_inputs)
|
|
1259
|
+
|
|
1260
|
+
|
|
1261
|
+
def check_input_format(input_param):
    """
    Validate a data-format string; only "NCHW" is accepted.

    Returns:
        The unchanged ``input_param`` when it equals "NCHW".

    Raises:
        ValueError: For any other value.
    """
    if input_param != "NCHW":
        raise ValueError(f"The data format must be NCHW, but got {input_param}.")
    return input_param
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def _expand_tuple(n_dimensions):
    """
    Build a converter that normalizes an int, or a tuple of ints, to an ``n_dimensions``-tuple.

    Args:
        n_dimensions (int): Length of the tuple the converter produces or accepts.

    Returns:
        Callable: ``convert(m)``. For an int ``m`` it returns ``(m,) * n_dimensions``. For a tuple
        ``m`` of exactly ``n_dimensions`` ints it returns ``m`` unchanged. ``bool`` is rejected in
        both cases.

    The converter raises TypeError if ``m`` is neither an int nor a tuple, if the tuple's length
    differs from ``n_dimensions``, or if any tuple element is not an int.
    """

    def convert(m):
        if not isinstance(m, tuple):
            # bool is a subclass of int but is not a valid dimension value.
            if isinstance(m, int) and not isinstance(m, bool):
                return tuple(repeat(m, n_dimensions))
            raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], " \
                            f"but got {type(m)}")

        # Compare by value: `is` on ints only works by accident of CPython's small-int cache
        # and wrongly rejected correctly sized tuples for n_dimensions > 256.
        if len(m) != n_dimensions:
            raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, " \
                            f"but got {m}")

        for i in m:
            if not isinstance(i, int) or isinstance(i, bool):
                raise TypeError(f"When expanding an int number to tuple, " \
                                f"the type of element in input tuple must be an integer, but got {type(i)}.")
        return m

    return convert
|
|
1289
|
+
|
|
1290
|
+
|
|
1291
|
+
def _check_data_type_valid(data, valid_type):
    """
    Return whether ``data`` matches ``valid_type``.

    A ``valid_type`` of ``None`` means "``data`` must be ``None``". Otherwise ``isinstance`` decides.
    A matching object that exposes a ``size`` attribute equal to 0 is rejected as empty: the
    problem is logged at critical level and a ValueError is raised.
    """
    if valid_type is None:
        return data is None
    if not isinstance(data, valid_type):
        return False
    if hasattr(data, 'size') and data.size == 0:
        empty_msg = "The input data can not be empty."
        logger.critical(empty_msg)
        raise ValueError(empty_msg)
    return True
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
def check_input_data(*data, data_class):
    """
    Recursively verify that every leaf of the given data has an accepted type.

    Lists and tuples are walked element by element, and dicts value by value. Every other object
    is a leaf and must satisfy ``_check_data_type_valid`` for ``data_class``. When ``data_class``
    is a tuple/list, at least one of its entries must be satisfied.

    Args:
        *data: Objects to check.
        data_class: Accepted type, ``None`` (only ``None`` is accepted), or a tuple/list of these.

    Raises:
        TypeError: If a leaf matches none of the accepted types.
        ValueError: Propagated from ``_check_data_type_valid`` when a matching leaf has ``size == 0``.
    """
    for item in data:
        if isinstance(item, (list, tuple, dict)):
            children = item.values() if isinstance(item, dict) else item
            for child in children:
                check_input_data(child, data_class=data_class)
            continue

        several = isinstance(data_class, (tuple, list))
        if several:
            # Every candidate is evaluated (list, not generator) before the result is inspected.
            accepted = any([_check_data_type_valid(item, candidate) for candidate in data_class])
        else:
            accepted = _check_data_type_valid(item, data_class)
        if accepted:
            continue

        if several:
            expected = tuple(cls.__name__ if hasattr(cls, '__name__') else cls for cls in data_class)
        elif data_class is None:
            expected = None
        else:
            expected = data_class.__name__
        got = None if item is None else type(item).__name__
        raise TypeError(f'The types of input data must be in the Union({expected}, ' \
                        f'tuple[{expected}], list[{expected}], dict[{expected}]), ' \
                        f'but got type {got}.')
|
|
1324
|
+
|
|
1325
|
+
|
|
1326
|
+
def check_input_dataset(*dataset, dataset_type):
    """
    Return True only if at least one dataset is given and every one is a ``dataset_type`` instance.
    """
    return bool(dataset) and all(isinstance(item, dataset_type) for item in dataset)
|
|
1334
|
+
|
|
1335
|
+
|
|
1336
|
+
def check_output_data(data):
    """
    Ensure a network produced an output.

    Raises:
        RuntimeError: If ``data`` is ``None``.
    """
    if data is not None:
        return
    raise RuntimeError('The output data can not be None, please check your net or input data.')
|
|
1340
|
+
|
|
1341
|
+
|
|
1342
|
+
# Converters that normalize an int (or an int tuple of matching length) to a 1-, 2- or 3-tuple.
once, twice, triple = (_expand_tuple(rank) for rank in (1, 2, 3))
|
|
1345
|
+
|
|
1346
|
+
|
|
1347
|
+
def args_type_check(*type_args, **type_kwargs):
    """
    Decorator factory that checks argument types when the decorated function is called.

    Expected types are bound to the decorated function's parameters as
    ``inspect.Signature.bind_partial`` binds values, so they may be given positionally or by name.
    Names that land in the function's ``**kwargs``-style parameter (whatever it is called) are
    checked individually, alongside the regular named parameters. A ``None`` argument value is
    always accepted.

    Args:
        *type_args: Expected types (or tuples of types), bound positionally.
        **type_kwargs: Expected types (or tuples of types), bound by name.

    Returns:
        Callable: A decorator that wraps a function with the type check.

    Raises:
        TypeError: At decoration time, if the types cannot be bound to the function's signature.
            At call time, if the call does not match the signature, or if an argument is not an
            instance of its declared type.
    """

    def type_check(func):
        sig = inspect.signature(func)
        # Name of the var-keyword (``**kwargs``-style) parameter, if the function has one.
        var_kw_name = next((param.name for param in sig.parameters.values()
                            if param.kind is inspect.Parameter.VAR_KEYWORD), None)

        def _flatten(arguments):
            """Merge the mapping captured by the var-keyword parameter into the top-level mapping."""
            flat = dict(arguments)
            if var_kw_name is not None and var_kw_name in flat:
                flat.update(flat.pop(var_kw_name))
            return flat

        # Resolved once, at decoration time, so calls never mutate shared closure state and
        # types declared for named parameters are not lost when the function also has **kwargs.
        bound_types = _flatten(sig.bind_partial(*type_args, **type_kwargs).arguments)

        @wraps(func)
        def wrapper(*args, **kwargs):
            argument_dict = _flatten(sig.bind(*args, **kwargs).arguments)
            for name, value in argument_dict.items():
                if name in bound_types:
                    if value is not None and not isinstance(value, bound_types[name]):
                        raise TypeError(f"The parameter '{name}' must be {bound_types[name]}, but got {type(value)}")
            return func(*args, **kwargs)

        return wrapper

    return type_check
|
|
1372
|
+
|
|
1373
|
+
|
|
1374
|
+
def check_hook_fn(hook_type, hook_fn):
    """
    Validate a user hook function before it is registered.

    Args:
        hook_type (str): Name of the registering API. The parameter count is checked for
            "register_hook", "register_forward_pre_hook", "register_forward_hook",
            "register_backward_pre_hook" and "register_backward_hook".
        hook_fn: The hook to validate; must be a Python function or method.

    Returns:
        bool: False, after logging a warning, when the context mode is not PyNative.
        True when all checks pass.

    Raises:
        TypeError: If ``hook_fn`` is not a function/method, if it is wrapped by ``@jit`` (its code
            object is named ``staging_specialize``), or if its parameter count does not match
            what ``hook_type`` requires.
    """
    if context.get_context("mode") != context.PYNATIVE_MODE:
        logger.warning(f"'{hook_type}' function is only supported in pynative mode, you can use "
                       f"context.set_context to set pynative mode.")
        return False

    if not isinstance(hook_fn, (FunctionType, MethodType)):
        raise TypeError(f"When using 'hook_type(hook_fn)', the type of 'hook_fn' must be python "
                        f"function, but got {type(hook_fn)}.")

    if hook_fn.__code__.co_name == "staging_specialize":
        raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")

    # hook_type -> (label used in the error message, "arg"/"args" wording, required parameter count)
    required_arity = {
        "register_hook": ("Tensor hook", "arg", 1),
        "register_forward_pre_hook": ("forward_pre_hook", "args", 2),
        "register_forward_hook": ("forward_hook", "args", 3),
        "register_backward_pre_hook": ("backward_pre_hook", "args", 2),
        "register_backward_hook": ("backward_hook", "args", 3),
    }
    # For a bound method, inspect.signature already excludes `self`.
    actual_count = len(inspect.signature(hook_fn).parameters)

    spec = required_arity.get(hook_type)
    if spec is not None:
        label, noun, expected_count = spec
        if actual_count != expected_count:
            raise TypeError(f"{label} function {hook_fn.__name__} {noun} num should be {expected_count}, "
                            f"but got {actual_count}")
    return True
|
|
1417
|
+
|
|
1418
|
+
|
|
1419
|
+
# Module-level mutable dict, initially empty.
# NOTE(review): nothing in the visible part of this file reads or writes it; confirm its purpose
# and its users before relying on it.
_set_record = {}
|