mindspore 2.4.0__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
- mindspore/_c_expression.cpython-310-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,880 @@
|
|
|
1
|
+
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Initializer for cell parameters."""
|
|
16
|
+
|
|
17
|
+
from __future__ import absolute_import
|
|
18
|
+
|
|
19
|
+
import numbers
|
|
20
|
+
import math
|
|
21
|
+
|
|
22
|
+
from functools import reduce
|
|
23
|
+
import numpy as np
|
|
24
|
+
from mindspore.common.seed import get_seed, _get_graph_seed
|
|
25
|
+
from mindspore.common import dtype as mstype
|
|
26
|
+
from mindspore.common.tensor import Tensor
|
|
27
|
+
from mindspore._c_expression import _random_normal, _random_uniform, _truncated_normal
|
|
28
|
+
|
|
29
|
+
# Registry mapping lower-cased initializer class names and their aliases to the
# initializer classes; filled in by the ``_register`` decorator below.
_INITIALIZER_ALIAS = {}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Initializer:
    """
    The abstract base class of the initializer.

    Note:
        Initializers are intended to be used for delayed initialization in parallel mode rather than Tensor
        initialization. If you have to use Initializers to create a Tensor, :func:`mindspore.Tensor.init_data` should be
        followed in most of the cases. For more information, please refer to `mindspore.Tensor.init_data
        <https://www.mindspore.cn/docs/en/master/api_python/mindspore/Tensor/mindspore.Tensor.init_data.html#
        mindspore-tensor-init-data>`_ .

    Args:
        kwargs (dict): Keyword arguments for Initializer.
    """

    def __init__(self, **kwargs):
        # Record the keyword arguments the concrete initializer was configured with.
        self._kwargs = kwargs
        # No explicit seed until one is assigned through the ``seed`` setter.
        self._seed = None

    @property
    def seed(self):
        """
        The ``(seed, seed2)`` pair used for random initialization.

        If a seed was assigned explicitly, the pair is ``(assigned_seed + 1, 0)``; otherwise it is
        derived from the global seed through ``_get_graph_seed(get_seed(), "init")``.
        """
        if self._seed is not None:
            return self._seed + 1, 0
        graph_seed, op_seed = _get_graph_seed(get_seed(), "init")
        return graph_seed, op_seed

    @seed.setter
    def seed(self, value):
        self._seed = value

    def _initialize(self, *args):
        """Fill the given array; concrete initializers must override this."""
        raise NotImplementedError('Must be overridden!')

    def __call__(self, arr):
        """Initialize ``arr`` by delegating to ``_initialize`` and return its result."""
        return self._initialize(arr)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _register(*aliases):
    """
    Build a class decorator that records an initializer in ``_INITIALIZER_ALIAS``.

    The decorated class is stored under its lower-cased class name first, then under
    each name in ``aliases``. Keys that are already present are left untouched, so the
    first class registered under a given key keeps it.

    Args:
        aliases (str): Extra lookup names for the decorated class.

    Returns:
        Callable, a decorator that registers the class and returns it unchanged.
    """
    def alias_reg(cls):
        for key in (cls.__name__.lower(),) + aliases:
            _INITIALIZER_ALIAS.setdefault(key, cls)
        return cls

    return alias_reg
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _assignment(arr, num):
    """
    Write ``num`` into ``arr`` in place and return the array.

    ``arr[:]`` is not valid on a 0-d array, so a 0-d ``arr`` is written through a
    1-element view and a 0-d view of it is returned.

    Args:
        arr (numpy.ndarray): Destination array, modified in place.
        num (Union[numbers.Number, numpy.ndarray]): Value(s) to write.

    Returns:
        numpy.ndarray, ``arr`` itself, or a 0-d view of it when ``arr`` is 0-d.
    """
    if arr.shape != ():
        arr[:] = num[:] if isinstance(num, np.ndarray) else num
        return arr
    flat_view = arr.reshape(1)
    flat_view[:] = num
    return flat_view.reshape(())
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _numpy_seed():
    """
    Draw a seed for the random kernels from NumPy's global random generator.

    Since the value comes from ``np.random``, calling ``numpy.random.seed`` with the
    same seed beforehand makes this function return the same value.

    Returns:
        numpy.int64, an integer in ``[1, 2**63)``.
    """
    upper_bound = 1 << 63
    return np.random.randint(low=1, high=upper_bound, dtype=np.int64)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _init_random_normal(mean, sigma, shape):
    """
    Return a float32 array of ``shape`` filled in place by ``_random_normal``.

    Args:
        mean (float): Mean of the normal distribution.
        sigma (float): Standard deviation; must not be negative.
        shape (tuple): Shape of the generated array.

    Raises:
        ValueError: If ``sigma`` is negative.
    """
    if sigma < 0:
        raise ValueError("sigma < 0")
    seed = _numpy_seed()
    samples = np.ndarray(shape=shape, dtype=np.float32)
    _random_normal(seed, samples, mean, sigma)
    return samples
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _init_random_uniform(a, b, shape):
    """
    Return a float32 array of ``shape`` filled in place by ``_random_uniform``.

    Args:
        a (float): Lower bound passed to ``_random_uniform``.
        b (float): Upper bound passed to ``_random_uniform``.
        shape (tuple): Shape of the generated array.
    """
    seed = _numpy_seed()
    samples = np.ndarray(shape=shape, dtype=np.float32)
    _random_uniform(seed, samples, a, b)
    return samples
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _init_truncated_normal(a, b, mean, sigma, shape):
    """
    Return a float32 array of ``shape`` filled in place by ``_truncated_normal``.

    Args:
        a (float): Lower truncation bound passed to ``_truncated_normal``
            (absolute value vs. multiple of sigma - TODO confirm against the kernel).
        b (float): Upper truncation bound passed to ``_truncated_normal`` (same caveat).
        mean (float): Mean of the normal distribution.
        sigma (float): Standard deviation; must not be negative.
        shape (tuple): Shape of the generated array.

    Raises:
        ValueError: If ``sigma`` is negative.
    """
    if sigma < 0:
        raise ValueError("sigma < 0")
    seed = _numpy_seed()
    samples = np.ndarray(shape=shape, dtype=np.float32)
    _truncated_normal(seed, samples, a, b, mean, sigma)
    return samples
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@_register('zeros')
class Zero(Initializer):
    """
    Generates an array with constant value of zero in order to initialize a tensor.

    This initializer can also be referred to by the names ``'zero'`` and ``'zeros'``.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Zero
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Zero(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('zeros', [1, 2, 3], mindspore.float32))
    """

    def _initialize(self, arr):
        # Overwrite every element of the array in place with 0.
        arr.fill(0)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@_register('ones')
class One(Initializer):
    """
    Generates an array with constant value of one in order to initialize a tensor.

    This initializer can also be referred to by the names ``'one'`` and ``'ones'``.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, One
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(One(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('ones', [1, 2, 3], mindspore.float32))
    """

    def _initialize(self, arr):
        # Overwrite every element of the array in place with 1.
        arr.fill(1)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _calculate_fan_in_and_fan_out(shape):
    """
    Compute the fan-in and fan-out of a weight with the given shape.

    ``shape[1]`` counts input units and ``shape[0]`` output units. For shapes with more
    than two dimensions, both are scaled by the product of the trailing dimensions
    (the receptive field size).

    Args:
        shape (tuple): Weight shape with at least two dimensions.

    Returns:
        Tuple, ``(fan_in, fan_out)``.

    Raises:
        ValueError: If ``shape`` has fewer than two dimensions.
    """
    ndim = len(shape)
    if ndim < 2:
        raise ValueError("'fan_in' and 'fan_out' can not be computed for tensor with fewer than"
                         " 2 dimensions, but got dimensions {}.".format(ndim))
    receptive_field_size = 1
    for extent in shape[2:]:
        receptive_field_size *= extent
    return shape[1] * receptive_field_size, shape[0] * receptive_field_size
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _calculate_correct_fan(shape, mode):
    """
    Select the fan-in or fan-out of ``shape`` according to ``mode``.

    Args:
        shape (tuple): Weight shape with at least two dimensions.
        mode (str): ``'fan_in'`` or ``'fan_out'`` (case-insensitive).

    Returns:
        The fan-in when ``mode`` is ``'fan_in'``, otherwise the fan-out.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    normalized_mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if normalized_mode not in valid_modes:
        raise ValueError("'mode' {} not supported, please use one of {}".format(normalized_mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(shape)
    if normalized_mode == 'fan_in':
        return fan_in
    return fan_out
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _calculate_gain(nonlinearity, param=None):
    """
    Return the recommended gain for a nonlinearity.

    Args:
        nonlinearity (str): Name of the nonlinearity. The linear/convolution names in
            ``linear_fns`` and ``'sigmoid'`` give 1, ``'tanh'`` gives 5/3, ``'relu'`` gives
            sqrt(2), and ``'leaky_relu'`` gives sqrt(2 / (1 + negative_slope ** 2)).
        param (Union[int, float, None]): Negative slope used with ``'leaky_relu'``. ``None``
            means 0.01. Any real number (including NumPy scalars) is accepted; bool is rejected.

    Returns:
        number, the gain.

    Raises:
        ValueError: If `nonlinearity` is not supported, or if `nonlinearity` is ``'leaky_relu'``
            and `param` is not a real number.
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, numbers.Real) and not isinstance(param, bool):
            # bool is a subclass of int (hence of numbers.Real), so it is excluded explicitly.
            # numbers.Real also admits NumPy scalar types such as numpy.float32.
            negative_slope = param
        else:
            raise ValueError("For 'HeUniform', 'negative_slope' {} is not a valid number. "
                             "When 'nonlinearity' has been set to "
                             "'leaky_relu', 'negative_slope' should be int or float type, but got "
                             "{}.".format(param, type(param)))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        supported = linear_fns + ['sigmoid', 'tanh', 'relu', 'leaky_relu']
        raise ValueError("For 'HeUniform', the argument 'nonlinearity' should be one of "
                         "{}, but got {}.".format(supported, nonlinearity))
    return res
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _calculate_in_and_out(arr):
    """
    Compute the number of input and output units of an array.

    ``arr.shape[1]`` counts input units and ``arr.shape[0]`` output units. Both are
    multiplied by the product of any trailing dimensions.

    Args:
        arr (Array): Input array with at least two dimensions.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.

    Raises:
        ValueError: If ``arr`` has fewer than two dimensions.
    """
    shape = arr.shape
    dim = len(shape)
    if dim < 2:
        raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1, "
                         "but got {}.".format(dim))
    counter = reduce(lambda x, y: x * y, shape[2:], 1)
    return shape[1] * counter, shape[0] * counter
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@_register('xavier_normal')
class XavierNormal(Initializer):
    r"""
    Generates an array with values sampled from Xavier normal distribution
    :math:`{N}(0, \text{sigma}^2)` in order to initialize a tensor, where

    .. math::
        sigma = gain * \sqrt{\frac{2}{n_{in} + n_{out}}}

    where :math:`gain` is an optional scaling factor, :math:`n_{in}` is the number of input units in the weight tensor,
    :math:`n_{out}` is the number of output units in the weight tensor.

    Args:
        gain (float): An optional scaling factor. Default: ``1`` .

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, XavierNormal
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(XavierNormal(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('xavier_normal', [1, 2, 3], mindspore.float32))
    """
    def __init__(self, gain=1):
        super(XavierNormal, self).__init__(gain=gain)
        self.gain = gain

    def _initialize(self, arr):
        # The fan computation raises ValueError for arrays with fewer than 2 dimensions.
        n_in, n_out = _calculate_fan_in_and_fan_out(arr.shape)
        sigma = self.gain * math.sqrt(2.0 / float(n_in + n_out))
        _assignment(arr, _init_random_normal(0, sigma, arr.shape))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@_register('xavier_uniform')
class XavierUniform(Initializer):
    r"""
    Generates an array with values sampled from Xavier uniform distribution
    :math:`{U}(-\text{boundary}, \text{boundary})` in order to initialize a tensor, where

    .. math::
        boundary = gain * \sqrt{\frac{6}{n_{in} + n_{out}}}

    where :math:`gain` is an optional scaling factor. :math:`n_{in}` is the number of input units in the weight tensor,
    :math:`n_{out}` is the number of output units in the weight tensor.

    For details of XavierUniform algorithm, please check
    `<http://proceedings.mlr.press/v9/glorot10a.html>`_.

    Args:
        gain (float): An optional scaling factor. Default: ``1`` .


    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, XavierUniform
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(XavierUniform(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('xavier_uniform', [1, 2, 3], mindspore.float32))
    """

    def __init__(self, gain=1):
        super(XavierUniform, self).__init__(gain=gain)
        self.gain = gain

    def _initialize(self, arr):
        # The fan computation raises ValueError for arrays with fewer than 2 dimensions.
        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr.shape)
        limit = self.gain * math.sqrt(6.0 / (fan_in + fan_out))
        _assignment(arr, _init_random_uniform(-limit, limit, arr.shape))
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@_register('he_uniform')
class HeUniform(Initializer):
    r"""
    Generates an array with values sampled from HeKaiming Uniform distribution
    :math:`{U}(-\text{boundary}, \text{boundary})` in order to initialize a tensor, where

    .. math::
        boundary = \text{gain} \times \sqrt{\frac{3}{fan\_mode}}

    where :math:`gain` is an optional scaling factor. If :math:`fan\_mode` is ``'fan_in'``,
    it is the number of input units of the weight tensor. If :math:`fan\_mode` is ``'fan_out'``,
    it is the number of output units of the weight tensor.

    For details of HeUniform algorithm, please check
    `<https://arxiv.org/abs/1502.01852>`_.

    Args:
        negative_slope (int, float): The negative slope of the rectifier used after this layer
            (only used when `nonlinearity` is 'leaky_relu'). Boolean values are not accepted. Default: ``0`` .
        mode (str): Either ``'fan_in'`` or ``'fan_out'`` . Choosing ``'fan_in'`` preserves the magnitude of the
            variance of the weights in the forward pass. Choosing ``'fan_out'`` preserves the magnitudes
            in the backwards pass. Default: ``'fan_in'`` .
        nonlinearity (str): The non-linear function, recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
            Default: ``'leaky_relu'`` .


    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, HeUniform
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(HeUniform(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('he_uniform', [1, 2, 3], mindspore.float32))
    """

    def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(HeUniform, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
        self.negative_slope = negative_slope
        self.mode = mode
        self.nonlinearity = nonlinearity

    def _initialize(self, arr):
        fan = _calculate_correct_fan(arr.shape, self.mode)
        gain = _calculate_gain(self.nonlinearity, self.negative_slope)
        std = gain / math.sqrt(fan)
        # A uniform distribution on [-sqrt(3)*std, sqrt(3)*std] has standard deviation std.
        boundary = math.sqrt(3.0) * std
        data = _init_random_uniform(-boundary, boundary, arr.shape)
        _assignment(arr, data)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
@_register('he_normal')
class HeNormal(Initializer):
    r"""
    Generates an array with values sampled from HeKaiming Normal distribution
    :math:`{N}(0, \text{sigma}^2)` in order to initialize a tensor, where

    .. math::
        sigma = \frac{gain} {\sqrt{fan\_mode}}

    where :math:`gain` is an optional scaling factor. :math:`fan\_mode` is the number of input or output units of
    the weight tensor, depending on the `mode` is 'fan_in' or 'fan_out'.

    For details of HeNormal algorithm, please check `<https://arxiv.org/abs/1502.01852>`_.

    Args:
        negative_slope (int, float): The negative slope of the rectifier used after this layer
            (only used when `nonlinearity` is 'leaky_relu'). Default: ``0`` .
        mode (str): Either ``'fan_in'`` or ``'fan_out'`` . Choosing ``'fan_in'`` preserves the magnitude of the
            variance of the weights in the forward pass. Choosing ``'fan_out'`` preserves the magnitudes
            in the backwards pass. Default: ``'fan_in'`` .
        nonlinearity (str): The non-linear function, recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
            Default: ``'leaky_relu'`` .


    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, HeNormal
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(HeNormal(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('he_normal', [1, 2, 3], mindspore.float32))
    """

    def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(HeNormal, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
        self.negative_slope = negative_slope
        self.mode = mode
        self.nonlinearity = nonlinearity

    def _initialize(self, arr):
        # Fan is validated first, then the gain, matching the order of error reporting.
        fan = _calculate_correct_fan(arr.shape, self.mode)
        sigma = _calculate_gain(self.nonlinearity, self.negative_slope) / math.sqrt(fan)
        _assignment(arr, _init_random_normal(0, sigma, arr.shape))
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
class Constant(Initializer):
    """
    Generates an array with constant value in order to initialize a tensor.

    Args:
        value (Union[int, numpy.ndarray]): The value to initialize. A numpy.ndarray value must be
            broadcastable to the shape of the initialized tensor.


    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Constant
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Constant(3), [1, 2, 3], mindspore.float32))
    """

    def __init__(self, value):
        super(Constant, self).__init__(value=value)
        self.value = value

    def _initialize(self, arr):
        # ``ndarray.fill`` expects a scalar, so a multi-element ndarray ``value`` (allowed by the
        # docstring) fails there. Broadcasting assignment handles scalars identically and also
        # accepts arrays broadcastable to ``arr``'s shape, including 0-d ``arr``.
        arr[...] = self.value
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@_register()
class Identity(Initializer):
    """
    Generates a 2 dimension identity matrix array in order to initialize a tensor.

    This initializer can also be referred to by the name ``'identity'``.

    Raises:
        ValueError: If the dimension of input tensor is not equal to 2.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Identity
        >>> from mindspore import Parameter
        >>> w1 = initializer(Identity(), [2, 3], mindspore.float32)
        >>> w2 = initializer('identity', [2, 3], mindspore.float32)
    """

    def _initialize(self, arr):
        ndim = len(arr.shape)
        if ndim != 2:
            raise ValueError('For Identity initializer, the dimension of the initialized tensor should be 2, '
                             'but got {}.'.format(ndim))
        rows, cols = arr.shape
        # Ones on the main diagonal, zeros elsewhere; works for non-square shapes too.
        _assignment(arr, np.eye(rows, cols))
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
@_register()
class Sparse(Initializer):
    """
    Generates a 2 dimension sparse matrix array in order to initialize a tensor. The non-zero positions
    will be filled with the value sampled from the normal distribution :math:`{N}(0, sigma^2)`.

    Args:
        sparsity (float): The fraction of elements being set to zero in each column. Must be in [0, 1].
        sigma (float): The standard deviation of the normal distribution. Default: ``0.01`` .

    Raises:
        ValueError: If `sparsity` is not in [0, 1].
        ValueError: If the dimension of input tensor is not equal to 2.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Sparse
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Sparse(sparsity=0.1, sigma=0.01), [5, 8], mindspore.float32))
    """

    def __init__(self, sparsity, sigma=0.01):
        super(Sparse, self).__init__()
        # A negative sparsity would produce a negative slice bound below and zero the wrong
        # number of rows; values above 1 are not a meaningful fraction.
        if not 0 <= sparsity <= 1:
            raise ValueError("For Sparse initializer, 'sparsity' should be in [0, 1], "
                             "but got {}.".format(sparsity))
        self.sparsity = sparsity
        self.sigma = sigma

    def _initialize(self, arr):
        if len(arr.shape) != 2:
            raise ValueError('For Sparse initializer, the dimension of the initialized tensor should be 2, '
                             'but got {}.'.format(len(arr.shape)))
        rows, cols = arr.shape
        zero_num = int(np.ceil(self.sparsity * rows))
        data = _init_random_normal(0, self.sigma, arr.shape)
        for col_idx in range(cols):
            # Zero `zero_num` distinct, randomly chosen rows of this column. Passing the int
            # `rows` keeps the index array integer-typed even when rows == 0.
            row_idx = np.random.permutation(rows)[:zero_num]
            data[row_idx, col_idx] = 0.
        _assignment(arr, data)
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
@_register()
class Dirac(Initializer):
    """
    Generates an array with the Dirac delta function in order to initialize a tensor.
    It's usually used in convolution layers, preserves as many identities of the inputs as possible.

    Args:
        groups (int): The number of groups in convolution layer. Each group applies the same initialization.
            Default: ``1`` .

    Raises:
        ValueError: If the dimension of the initialized tensor is not in [3, 4, 5].
        ValueError: If `groups` is not positive.
        ValueError: If the first dimension of the initialized tensor is not divisible by `groups`.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Dirac
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Dirac(groups=2), [6, 4, 3, 3], mindspore.float32))
        >>> w2 = Parameter(initializer("dirac", [6, 4, 3, 3], mindspore.float32))
    """

    def __init__(self, groups=1):
        super(Dirac, self).__init__()
        self.groups = groups

    def _initialize(self, arr):
        shapes = arr.shape
        dimension = len(shapes)
        if dimension not in [3, 4, 5]:
            raise ValueError("For Dirac initializer, only support "
                             "to initialize tensor with dimension of 3, 4 or 5, but got {}.".format(dimension))
        # Without this check groups == 0 raises ZeroDivisionError and negative groups silently
        # leave the tensor all zeros.
        if self.groups <= 0:
            raise ValueError("For Dirac initializer, 'groups' must be a positive integer, "
                             "but got {}.".format(self.groups))
        if shapes[0] % self.groups != 0:
            raise ValueError("For Dirac initializer, the first dimension of "
                             "the initialized tensor must be divisible by groups, "
                             "but got first dimension {}, groups {}.".format(shapes[0], self.groups))

        out_channel_per_group = shapes[0] // self.groups
        min_dim = min(out_channel_per_group, shapes[1])
        # Centre position of every spatial (trailing) dimension; valid for 1-D, 2-D and 3-D kernels.
        center = tuple(size // 2 for size in shapes[2:])

        data = np.zeros(shapes)
        for group in range(self.groups):
            for dim in range(min_dim):
                # Connect output channel `dim` of this group to input channel `dim` at the kernel centre.
                data[(group * out_channel_per_group + dim, dim) + center] = 1
        _assignment(arr, data)
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
@_register()
class Orthogonal(Initializer):
    r"""
    Generates a (semi) orthogonal matrix array in order to initialize a tensor.
    The input tensor must have at least 2 dimensions.
    If it has more than 2 dimensions, all trailing dimensions are flattened into one.

    Args:
        gain (float): An optional scaling factor. Default: ``1.0`` .

    Raises:
        ValueError: If the dimension of input tensor is less than 2.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Orthogonal
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Orthogonal(gain=2.), [2, 3, 4], mindspore.float32))
        >>> w2 = Parameter(initializer('orthogonal', [2, 3, 4], mindspore.float32))
    """

    def __init__(self, gain=1.):
        super(Orthogonal, self).__init__(gain=gain)
        self.gain = gain

    def _initialize(self, arr):
        """Fill `arr` with a gain-scaled (semi) orthogonal matrix built from a QR factorisation."""
        ndim = len(arr.shape)
        if ndim < 2:
            raise ValueError('For Orthogonal initializer, the dimension of the initialized tensor should'
                             ' be no less than 2, but got {}.'.format(ndim))

        # View the tensor as a (first dim) x (product of the remaining dims) matrix.
        n_rows = arr.shape[0]
        n_cols = np.prod(arr.shape) // n_rows
        flat = _init_random_normal(0, 1, (n_rows, n_cols))

        # A wide matrix is factored through its transpose and transposed back afterwards.
        wide = n_rows < n_cols
        q, r = np.linalg.qr(flat.T if wide else flat)
        # Rescale the columns of Q by the sign of R's diagonal so the factorisation is unique.
        q *= np.sign(np.diag(r))
        if wide:
            q = q.T

        _assignment(arr, (q * self.gain).reshape(arr.shape))
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
@_register()
class VarianceScaling(Initializer):
    r"""
    Generates a random array with scaling in order to initialize a tensor.
    When `distribution` is 'truncated_normal' or 'untruncated_normal', the value will be sampled from truncated or
    untruncated normal distribution with a mean of 0 and a scaled standard deviation
    :math:`stddev = \sqrt{\frac{scale}{n}}`. :math:`n` will be the number of input units if `mode` is ``'fan_in'``,
    while :math:`n` will be
    the number of output units if `mode` is ``'fan_out'``. :math:`n` will be the average of ``'fan_in'``
    and ``'fan_out'`` if `mode` is ``'fan_avg'``.
    When `distribution` is ``'uniform'``, the value will be sampled from a uniform distribution within the limit of
    :math:`[-\sqrt{\frac{3*scale}{n}}, \sqrt{\frac{3*scale}{n}}]`.

    Args:
        scale (float): The scaling factor. Default: ``1.0`` .
        mode (str): Should be ``'fan_in'`` , ``'fan_out'`` or ``'fan_avg'`` . Default: ``'fan_in'`` .
        distribution(str): The type of distribution chosen to sample values. It should be
            ``'uniform'`` , ``'truncated_normal'`` or ``'untruncated_normal'`` . Default: ``'truncated_normal'`` .

    Raises:
        ValueError: If `scale` is not greater than 0.
        ValueError: If `mode` is not ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``.
        ValueError: If `distribution` is not ``'uniform'``, ``'truncated_normal'`` or ``'untruncated_normal'``.

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, VarianceScaling
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(VarianceScaling(scale=1.0, mode='fan_out',
        ...                                            distribution='untruncated_normal'), [2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('varianceScaling', [2, 3], mindspore.float32))
    """

    def __init__(self, scale=1.0, mode='fan_in', distribution='truncated_normal'):
        super(VarianceScaling, self).__init__(scale=scale, mode=mode, distribution=distribution)
        if scale <= 0.:
            raise ValueError("For VarianceScaling initializer, "
                             "the argument 'scale' must be greater than 0, but got {}.".format(scale))

        if mode not in ['fan_in', 'fan_out', 'fan_avg']:
            raise ValueError("For VarianceScaling initializer, the argument 'mode' must be fan_in, "
                             "fan_out or fan_avg, but got {}.".format(mode))

        if distribution not in ['uniform', 'truncated_normal', 'untruncated_normal']:
            raise ValueError("For VarianceScaling initializer, the argument 'distribution' must be uniform, "
                             "truncated_normal or untruncated_normal, but got {}.".format(distribution))

        self.scale = scale
        self.mode = mode
        self.distribution = distribution

    def _initialize(self, arr):
        """Sample `arr` with variance ``scale / n``, where `n` is selected by `mode`."""
        scale = self.scale
        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr.shape)
        if self.mode == 'fan_in':
            scale /= max(1., fan_in)
        elif self.mode == 'fan_out':
            scale /= max(1., fan_out)
        else:
            scale /= max(1., (fan_in + fan_out) / 2.)

        # Compare against the exact names accepted by __init__.
        # The old 'truncated_norm' check never matched, so the default silently sampled uniform.
        if self.distribution == 'truncated_normal':
            # 0.8796... is the std of a standard normal truncated to [-2, 2].
            # Dividing by it restores the requested std after truncation.
            # NOTE(review): assumes _init_truncated_normal takes its bounds in units of stddev
            # (scipy truncnorm convention); confirm.
            stddev = np.sqrt(scale) / 0.87962566103423978
            data = _init_truncated_normal(-2, 2, 0, stddev, arr.shape)
        elif self.distribution == 'untruncated_normal':
            stddev = np.sqrt(scale)
            data = _init_random_normal(0, stddev, arr.shape)
        else:
            # Uniform on [-limit, limit] has variance limit**2 / 3 == scale.
            limit = np.sqrt(3.0 * scale)
            data = _init_random_uniform(-limit, limit, arr.shape)
        _assignment(arr, data)
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
@_register()
class Uniform(Initializer):
    r"""
    Generates an array with values sampled from Uniform distribution :math:`{U}(-\text{scale}, \text{scale})` in order
    to initialize a tensor.

    Args:
        scale (float): The bound of the Uniform distribution; samples lie in ``[-scale, scale]``.
            Default: ``0.07`` .

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Uniform
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Uniform(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('uniform', [1, 2, 3], mindspore.float32))
    """

    def __init__(self, scale=0.07):
        super(Uniform, self).__init__(scale=scale)
        self.scale = scale

    def _initialize(self, arr):
        """Fill `arr` with values drawn from a range symmetric around zero with half-width `scale`."""
        bound = self.scale
        _assignment(arr, _init_random_uniform(-bound, bound, arr.shape))
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
@_register()
class Normal(Initializer):
    r"""
    Generates an array with values sampled from Normal distribution
    :math:`{N}(\text{mean}, \text{sigma}^2)` in order to initialize a tensor.

    .. math::
        f(x) = \frac{1}{\sqrt{2\pi}\,\text{sigma}}
               \exp\left(-\frac{(x - \text{mean})^2}{2\,\text{sigma}^2}\right)

    Args:
        sigma (float): The standard deviation of Normal distribution. Default: ``0.01`` .
        mean (float): The mean of Normal distribution. Default: ``0.0`` .

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, Normal
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(Normal(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('normal', [1, 2, 3], mindspore.float32))
    """

    def __init__(self, sigma=0.01, mean=0.0):
        super(Normal, self).__init__(sigma=sigma, mean=mean)
        self.mean = mean
        self.sigma = sigma

    def _initialize(self, arr):
        """Fill `arr` with samples from ``_init_random_normal(mean, sigma, shape)``."""
        samples = _init_random_normal(self.mean, self.sigma, arr.shape)
        _assignment(arr, samples)
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
@_register()
class TruncatedNormal(Initializer):
    r"""
    Generates an array with values sampled from Truncated Normal distribution in order to initialize a tensor.

    Args:
        sigma (float): The standard deviation of Truncated Normal distribution. Default: ``0.01`` .
        mean (float): The mean of Truncated Normal distribution. Default: ``0.0`` .
        a (float): The lower bound of the truncated interval. Default: ``-2.0`` .
        b (float): The upper bound of the truncated interval. Default: ``2.0`` .

    Examples:
        >>> import mindspore
        >>> from mindspore.common.initializer import initializer, TruncatedNormal
        >>> from mindspore import Parameter
        >>> w1 = Parameter(initializer(TruncatedNormal(), [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('truncatedNormal', [1, 2, 3], mindspore.float32))
    """

    def __init__(self, sigma=0.01, mean=0.0, a=-2.0, b=2.0):
        super(TruncatedNormal, self).__init__(sigma=sigma, mean=mean, a=a, b=b)
        self.sigma, self.mean = sigma, mean
        self.a, self.b = a, b

    def _initialize(self, arr):
        """Fill `arr` with samples from ``_init_truncated_normal(a, b, mean, sigma, shape)``."""
        lower, upper = self.a, self.b
        samples = _init_truncated_normal(lower, upper, self.mean, self.sigma, arr.shape)
        _assignment(arr, samples)
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def initializer(init, shape=None, dtype=mstype.float32):
    """
    Create and initialize a tensor.

    Args:
        init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.

            - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
              class will be called in practice. The value of `init` can be ``"normal"``, ``"ones"`` or
              ``"zeros"``, etc.

            - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.

            - `numbers.Number`: The `Constant` will be called to initialize tensor.

            - `Tensor`: The tensor itself is returned unchanged (`dtype` is not applied). If `shape` is given,
              it must match the tensor's shape.

        shape (Union[tuple, list, int]): The shape of the initialized tensor. Default: ``None`` .
        dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: ``mstype.float32`` .

    Returns:
        Tensor, return is Tensor object.

    Raises:
        TypeError: The type of the argument 'init' is not correct.
        ValueError: The shape of the tensor which is passed through 'init' is not the same as that passed by 'shape'.
        ValueError: A value in 'shape' is not a positive integer.
        ValueError: 'init' is a string that does not name a registered initializer.


    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindspore.common.initializer import initializer, One
        >>> from mindspore import Parameter
        >>> data = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
        >>> w1 = Parameter(initializer(data, [1, 2, 3], mindspore.float32))
        >>> w2 = Parameter(initializer('ones', [1, 2, 3], mindspore.float32))
        >>> w3 = Parameter(initializer(One(), [1, 2, 3], mindspore.float32))
        >>> w4 = Parameter(initializer(0, [1, 2, 3], mindspore.float32))
    """
    if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
        raise TypeError("For 'initializer', the type of the 'init' argument should be 'Tensor', 'number', 'string' "
                        "or 'initializer', but got {}.".format(type(init)))

    if isinstance(init, Tensor):
        # Only validate when the caller supplied a shape.
        # Previously a default shape=None was wrapped into [None] and always rejected.
        if shape is not None:
            expected = shape if isinstance(shape, (tuple, list)) else [shape]
            if init.shape != tuple(expected):
                raise ValueError("For 'initializer', the shape of the 'init' argument should be same as "
                                 "the argument 'shape', but got the "
                                 "'init' shape {} and the 'shape' {}.".format(list(init.shape), expected))
        return init

    # Normalise `shape` to a tuple (or leave it as None).
    if isinstance(shape, list):
        shape = tuple(shape)
    elif isinstance(shape, numbers.Number):
        shape = (shape,)

    if shape is not None:
        for value in shape:
            if not isinstance(value, int) or value <= 0:
                raise ValueError(f"For 'initializer', the argument 'shape' is invalid, the value of 'shape' "
                                 f"must be positive integer, "
                                 f"but got {shape}")

    if isinstance(init, str):
        class_name = _INITIALIZER_ALIAS.get(init.lower())
        if class_name is None:
            raise ValueError(f"For 'initializer', the class corresponding to '{init}' was not found.")
        init = class_name()
    elif isinstance(init, numbers.Number):
        init = Constant(init)
    shape = shape if shape is not None else init.shape
    init_obj = Tensor(dtype=dtype, shape=shape, init=init)
    return init_obj
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
# Public API of this module: the `initializer` factory, the base class, and every built-in initializer.
__all__ = [
    'Initializer',
    'initializer',
    'TruncatedNormal',
    'Normal',
    'Uniform',
    'HeUniform',
    'HeNormal',
    'XavierUniform',
    'XavierNormal',
    'One',
    'Zero',
    'Constant',
    'Identity',
    'Sparse',
    'Dirac',
    'Orthogonal',
    'VarianceScaling',
]
|