mindspore 2.4.0__cp310-cp310-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
- mindspore/_c_expression.cpython-310-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,968 @@
|
|
|
1
|
+
|
|
2
|
+
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ============================================================================
|
|
16
|
+
"""Cell_wrapper."""
|
|
17
|
+
from __future__ import absolute_import
|
|
18
|
+
from __future__ import division
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
from types import FunctionType, MethodType
|
|
22
|
+
|
|
23
|
+
from mindspore import log as logger
|
|
24
|
+
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
|
|
25
|
+
_get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
|
|
26
|
+
from mindspore.context import ParallelMode
|
|
27
|
+
from mindspore import _checkparam as validator
|
|
28
|
+
from mindspore import ops, nn
|
|
29
|
+
from mindspore.common import dtype as mstype
|
|
30
|
+
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
31
|
+
from mindspore.common.tensor import Tensor
|
|
32
|
+
from mindspore.ops.primitive import _primexpr
|
|
33
|
+
from mindspore.ops import composite as C
|
|
34
|
+
from mindspore.ops import functional as F
|
|
35
|
+
from mindspore.ops import operations as P
|
|
36
|
+
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
|
37
|
+
from mindspore.nn.cell import Cell
|
|
38
|
+
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
|
39
|
+
from mindspore.utils import ExitByRequest
|
|
40
|
+
|
|
41
|
+
# Hyper map graph used to query the dtype of each tensor in a sequence of
# parameters; one overload is registered per supported input type below.
_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire parameter datatype.

    Registered overload of `_get_datatype` for a single Tensor input.

    Args:
        param (Tensor): The parameter before operation.

    Returns:
        mstype, the datatype of parameter.
    """
    return F.dtype(param)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Hyper map graph used to cast each tensor in a sequence of parameters to a
# target dtype; one overload is registered per supported input signature.
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast gradient to datatype.

    Registered overload of `_cast_datatype` for a (type, Tensor) input pair.

    Args:
        datatype (mstype): the destination datatype of parameter.
        param (Tensor): The parameter before operation.

    Returns:
        Tensor, the parameter after operation.
    """
    return F.cast(param, datatype)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class WithLossCell(Cell):
    r"""
    Wrap a backbone network together with a loss function.

    The wrapped cell accepts `data` and `label` as inputs, feeds `data`
    through the backbone, applies the loss function to the backbone output
    and `label`, and returns the computed loss.

    Args:
        backbone (Cell): The backbone network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.

    Raises:
        TypeError: If dtype of `data` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor, nn
        >>> import numpy as np
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        # auto_prefix=False keeps the wrapped backbone's parameter names
        # unchanged so checkpoints stay compatible.
        self._loss_fn = loss_fn
        self._backbone = backbone
        self._get_attr_from_cell(backbone)

    def construct(self, data, label):
        logits = self._backbone(data)
        return self._loss_fn(logits, label)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, the backbone network.

        Examples:
            >>> from mindspore import nn
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
            >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
            >>> backbone = net_with_criterion.backbone_network
        """
        return self._backbone
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class WithGradCell(Cell):
    r"""
    Wrap a network so that calling it returns gradients of trainable parameters.

    A loss function is required unless `network` already bundles one (for
    example a :class:`WithLossCell`): when `loss_fn` is ``None`` the network
    is used as the loss-producing cell directly. Calling the wrapped cell
    with '\*inputs' runs the backward pass and yields one gradient per
    trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap. The network only supports single output.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: ``None`` .
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitive for backpropagation, the type and shape
            must be same as the `network` output. If ``None`` , we will fill one to a same type shape of
            output value. Default: ``None`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with identical shapes as trainable weights.

    Raises:
        TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> # Defined a network without loss function, taking LeNet5 as an example.
        >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.sens = sens
        self.weights = ParameterTuple(network.trainable_params())
        # sens_param is enabled only when an explicit sensitivity was given.
        self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
        # Without a loss function the network itself must already emit a loss.
        self.network_with_loss = network if loss_fn is None else WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        grad_fn = self.grad(self.network_with_loss, self.weights)
        if self.sens is not None:
            return grad_fn(*inputs, self.sens)
        return grad_fn(*inputs)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ForwardValueAndGrad(Cell):
    r"""
    Encapsulate a training network to return both its forward value and gradients.

    The wrapped cell runs the forward computation on '\*inputs' and builds the
    backward graph through a gradient function, returning the forward result
    together with the requested gradients.

    Args:
        network (Union[Cell, Function, MethodType]): The training network.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
            Default: ``None`` .
        get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .
        get_by_list (bool): If ``True`` , get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both ``False`` , get the gradient with respect to first input.
            If get_all and get_by_list are both ``True`` , get the gradients with respect to inputs and Parameter
            variables at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: ``False`` .
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is ``False`` , a 'ones_like(outputs)' sensitivity will be attached automatically.
            If ``True`` , the sensitivity must be transferred through the input parameters.
            Default: ``False`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
        - **sens** - A sensitivity (gradient with respect to output) as the input of backpropagation.
          If network has single output, the sens is a tensor.
          If network has multiple outputs, the sens is the tuple(tensor).

    Outputs:
        - **forward value** - The result of network forward running.
        - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, nn, ops, ParameterTuple, Parameter
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="weight")
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.weight)
        ...         return out
        ...
        >>> net = Net()
        >>> criterion = nn.SoftmaxCrossEntropyWithLogits()
        >>> net_with_criterion = nn.WithLossCell(net, criterion)
        >>> weight = ParameterTuple(net.trainable_params())
        >>> train_network = nn.ForwardValueAndGrad(net_with_criterion, weights=weight, get_all=True, get_by_list=True)
        >>> inputs = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> labels = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> loss, grads = train_network(inputs, labels)
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        # Fail fast on misuse: validate argument types before any state is set.
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the argument 'network' must be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if not isinstance(get_all, bool):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the type of 'get_all' must be bool, but got '{type(get_all)}'")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the type of 'get_by_list' must be bool, but got '{type(get_by_list)}'")
        if get_by_list and not isinstance(weights, (ParameterTuple, tuple, list)):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"when 'get_by_list' is set to True, the argument 'weights' must be "
                            f"Parameters array, but got '{type(weights)}'")
        self.network = network
        if isinstance(network, Cell):
            # Mark the cell so that the backward graph is kept for later use.
            self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=get_all, get_by_list=get_by_list, sens_param=sens_param)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        # The gradient call receives every input (including a trailing
        # sensitivity when sens_param is on); the forward call drops the
        # sensitivity argument.
        all_inputs = inputs
        fwd_inputs = inputs[:-1] if self.sens_param else inputs
        loss = self.network(*fwd_inputs)
        if self.get_by_list:
            grad_fn = self.grad(self.network, self.weights)
        else:
            grad_fn = self.grad(self.network)
        return loss, grad_fn(*all_inputs)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the `network` with the `optimizer`. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the construct function to update the parameter. Different
    parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Union[Cell]): Optimizer for updating the network parameters.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is
            ``None`` , which is ``1.0`` .
        return_grad (bool): Whether to return gradient. If ``True``, it will return the gradient in the form of a dict
            while returning loss. The key of the dict is the parameter name corresponding to the gradient, and value
            is the gradient value. Default value is ``False`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.

    Raises:
        TypeError: If `sens` is not a numbers.Number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> #1) Using the WithLossCell provided by MindSpore
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> #2) Using user-defined WithLossCell
        >>> class MyWithLossCell(nn.Cell):
        ...    def __init__(self, backbone, loss_fn):
        ...        super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...        self._backbone = backbone
        ...        self._loss_fn = loss_fn
        ...
        ...    def construct(self, x, y, label):
        ...        out = self._backbone(x, y)
        ...        return self._loss_fn(out, label)
        ...
        ...    @property
        ...    def backbone_network(self):
        ...        return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=None, return_grad=False):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        # Two grad ops: one that takes an explicit sensitivity, one that does not.
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.grad_no_sens = C.GradOperation(get_by_list=True)
        self.sens = sens
        if self.sens == 0:
            raise ValueError("The input argument of 'sens' can not be 0.")
        # sense_flag records whether the caller supplied a sensitivity; when it
        # did not, the no-sens grad path is used and sens defaults to 1.0.
        self.sense_flag = True
        if self.sens is None:
            self.sense_flag = False
            self.sens = 1.0
        self.return_grad = return_grad
        if return_grad:
            # Parameter names are cached so gradients can be returned as a dict.
            self.weights_name = [i.name for i in self.optimizer.parameters]
        self.reducer_flag = False
        self.grad_reducer = nn.Identity()
        self.parallel_mode = _get_parallel_mode()
        # A gradient all-reduce is only needed in data/hybrid parallel modes
        # or under PyNative parallel execution.
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
            _is_pynative_parallel()
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            from mindspore.communication.management import GlobalComm
            group = GlobalComm.WORLD_COMM_GROUP
            if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)):
                # AdaSum optimizers reduce within per-server groups of 8
                # devices instead of the world group; build the group for the
                # server this rank belongs to.
                from mindspore.communication.management import get_group_size, create_group, get_rank
                group_number = get_group_size() // 8
                self.degree = int(self.degree / group_number)
                group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)]
                current_index = get_rank() // 8
                server_group_name = "allreduce_" + str(current_index)
                create_group(server_group_name, group_list[current_index])
                group = server_group_name
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
        self._get_attr_from_cell(network)
        # Graceful-exit support is opt-in via the MS_ENABLE_GRACEFUL_EXIT env var.
        self.use_graceful_exit = os.environ.get("MS_ENABLE_GRACEFUL_EXIT") == "1"
        if self.use_graceful_exit:
            self.graceful_exit = ExitByRequest()
            self.exit_param = Parameter(Tensor(False, mstype.bool_), name="graceful_exit")  # update by reduce value
            self.init_param = Parameter(Tensor([0], mstype.int32), name="graceful_init")  # update by config file

    def construct(self, *inputs):
        if not self.sense_flag:
            return self._no_sens_impl(*inputs)
        loss = self.network(*inputs)
        # Fill the sensitivity with the user-provided scaling value.
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_graceful_exit:
            grads = self.graceful_exit.exit_by_request(grads, self.init_param, self.exit_param)
        # depend ties the optimizer update to the returned loss so the update
        # is not pruned from the graph.
        loss = F.depend(loss, self.optimizer(grads))
        if self.return_grad:
            grad_with_param_name = {}
            for index, value in enumerate(grads):
                grad_with_param_name[self.weights_name[index]] = value
            return loss, grad_with_param_name
        return loss

    def _no_sens_impl(self, *inputs):
        """construct implementation used when no 'sens' argument was passed in."""
        loss = self.network(*inputs)
        grads = self.grad_no_sens(self.network, self.weights)(*inputs)
        grads = self.grad_reducer(grads)
        if self.use_graceful_exit:
            grads = self.graceful_exit.exit_by_request(grads, self.init_param, self.exit_param)
        loss = F.depend(loss, self.optimizer(grads))
        if self.return_grad:
            grad_with_param_name = {}
            for index, value in enumerate(grads):
                grad_with_param_name[self.weights_name[index]] = value
            return loss, grad_with_param_name
        return loss
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class GetNextSingleOp(Cell):
    """
    Cell that fetches the next element from a device-side dataset queue.

    For detailed information, refer to :class:`mindspore.ops.GetNext`.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of dataset.
        dataset_shapes (list[tuple[int]]): The shapes of dataset.
        queue_name (str): Queue name to fetch the data.

    Outputs:
        tuple[Tensor], the data gets from Dataset.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import ops, nn
        >>> from mindspore import dataset as ds
        >>> from mindspore import dtype as mstype
        >>>
        >>> data_path = "/path/to/MNIST_Data/train/"
        >>> train_dataset = ds.MnistDataset(data_path, num_samples=10)
        >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> dataset = dataset_helper.iter.dataset
        >>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
        >>> queue_name = dataset.__transfer_dataset__.queue_name
        >>> get_next_single_op_net = nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
        >>> data, label = get_next_single_op_net()
        >>> relu = ops.ReLU()
        >>> result = relu(data.astype(mstype.float32))
        >>> print(result.shape)
        (28, 28, 1)
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        output_num = len(dataset_types)
        self.get_next = P.GetNext(dataset_types, dataset_shapes, output_num, queue_name)

    def construct(self):
        # Each call pops one batch of tensors from the named queue.
        return self.get_next()
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
class _VirtualDatasetCell(Cell):
    """
    Wrap a network with a virtual dataset so a data-parallel input layout can be
    converted to the model-parallel layout.

    _VirtualDataset is a virtual Primitive that never appears in the final
    executing graph; inputs and outputs of _VirtualDataset are distributed in the
    data-parallel pattern, and tensor-redistribution Primitives are inserted
    dynamically while the graph is compiled.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()
        self._get_attr_from_cell(backbone)

    def construct(self, *inputs):
        # Route every input through the virtual dataset before the backbone sees it.
        virtual_outputs = self._virtual_dataset(*inputs)
        return self._backbone(*virtual_outputs)
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
@_primexpr
def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
    """Validate that dimension 0 of `input_shape` divides evenly by `micro_size`."""
    batch_dim = input_shape[0]
    # A dynamic (non-constant) batch dimension cannot be validated at compile time.
    if F.isconstant(batch_dim) is False:
        return
    if batch_dim % micro_size != 0:
        raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({batch_dim}) must be "
                         f"divided by micro size({micro_size})")
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
class _MicroBatch(Cell):
    """
    Transform a mini-batch into micro-batches in pipeline parallel.

    Args:
        micro_size (int): The number of micro-batches the mini-batch is split into.
    """
    def __init__(self, micro_size):
        super(_MicroBatch, self).__init__()
        self.shape = P.Shape()
        self.micro_size = micro_size
        self.strided_slice = P.StridedSlice()

    def construct(self, i, *inputs):
        """construct for _MicroBatch: return the i-th micro-batch slice of every input."""
        micro_inputs = ()
        for each_input in inputs:
            input_shape = self.shape(each_input)
            _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
            # The i-th slice along axis 0: [batch/micro_size * i, batch/micro_size * (i + 1)).
            micro_batch_begin = (input_shape[0] // self.micro_size) * i
            micro_batch_end = (input_shape[0] // self.micro_size) * (i + 1)
            strided_slice_begin = (micro_batch_begin,)
            strided_slice_strides = (1,)
            # All remaining axes are taken in full with stride 1.
            for _ in range(len(input_shape) - 1):
                strided_slice_begin += (0,)
                strided_slice_strides += (1,)
            strided_slice_end = (micro_batch_end,)
            strided_slice_end += input_shape[1:]
            micro_input = self.strided_slice(each_input, strided_slice_begin, strided_slice_end, strided_slice_strides)
            micro_inputs += (micro_input,)
        return micro_inputs
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
class MicroBatchInterleaved(Cell):
    """
    This function splits the input at the 0th dimension into interleave_num pieces and then performs
    the computation of the wrapped cell. Application scenario: When there is model parallelism in semi-automatic mode
    and network, if the first slice data is calculating forward, the second slice data will execute the
    communication operators at the same time, to achieve the performance acceleration of communication and computing
    concurrency.

    Note:
        The output of the input network must be a single tensor.

    Args:
        network (Cell): The target network to wrap.
        interleave_num (int, optional): split num of batch size. Default: ``2`` .

    Inputs:
        tuple[Tensor]. It's the same with the input of the `network` .

    Outputs:
        Tensor. The output of the input `network` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.MicroBatchInterleaved(net, 2)
    """
    def __init__(self, network, interleave_num=2):
        super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
        if not isinstance(interleave_num, int):
            raise TypeError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be integer, "
                            "but got the type : {}.".format(type(interleave_num)))
        if interleave_num <= 0:
            raise ValueError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be large than 0, "
                             "but got {}.".format(interleave_num))
        self.network = network
        self.interleave_num = interleave_num
        self.interleave_inputs = nn.CellList()
        # Tag the accumulation Add so the parallel compiler can recognize it.
        self.add = P.Add().add_prim_attr("micro_interleaved_add_flag", True)
        for _ in range(interleave_num):
            interleave_data = _MicroBatch(interleave_num)
            # Tag each slicing op so graph compilation can identify the interleave pattern.
            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
            interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
            self.interleave_inputs.append(interleave_data)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        # Run every interleaved slice through the network and sum the results.
        output = 0.0
        for i in range(self.interleave_num):
            interleave_input = self.interleave_inputs[i](i, *inputs)
            output = self.add(output, self.network(*interleave_input))
        return output
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
class PipelineCell(Cell):
    """
    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training.

    Note:
        micro_size must be greater or equal to pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): MicroBatch size.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.PipelineCell(net, 4)
    """
    def __init__(self, network, micro_size):
        super(PipelineCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        if not isinstance(network, Cell):
            raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
                            "but got the type : {}.".format(type(network)))
        if not isinstance(micro_size, int):
            raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
                            "but got the type : {}.".format(type(micro_size)))
        if micro_size <= 0:
            raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, "
                             "but got {}.".format(micro_size))
        # Build one _MicroBatch slicer and one dedicated Add per micro-step;
        # each Add is tagged with its micro-step index via the "pipeline_end" attr.
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("pipeline_end", i)
            self.add_list.append(self.add)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        # Accumulate the network output of every micro-batch with its matching Add.
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret
|
|
685
|
+
|
|
686
|
+
class GradAccumulationCell(Cell):
    """
    Wrap the network with Micro Batch to enable the grad accumulation in semi_auto_parallel/auto_parallel mode.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): MicroBatch size.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.GradAccumulationCell(net, 4)
    """
    def __init__(self, network, micro_size):
        super(GradAccumulationCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        if not isinstance(network, Cell):
            raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
                            "but got the type : {}.".format(type(network)))
        if not isinstance(micro_size, int):
            raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
                            "but got the type : {}.".format(type(micro_size)))
        if micro_size <= 0:
            raise ValueError("For 'GradAccumulationCell', the argument 'micro_size' must be large than 0, "
                             "but got {}.".format(micro_size))
        # One _MicroBatch slicer (tagged with "grad_accu_num") and one dedicated
        # Add (tagged with its micro-step index via "forward_end") per micro-step.
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("forward_end", i)
            self.add_list.append(self.add)
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        # Accumulate the network output of every micro-batch with its matching Add.
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def _pipeline_clear_grad(accu_grad, grad):
    """Zero an accumulation buffer once its paired gradient has been produced."""
    # depend() orders the clear after `grad` exists; then overwrite with zeros.
    ordered_accu = F.depend(accu_grad, grad)
    return F.assign(ordered_accu, F.zeros_like(ordered_accu))
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
class _TrainGradAccuStepCell(TrainOneStepCell):
    """
    Wraps the network with an optimizer in pipeline mode.

    Gradients are mirrored into `accu_grads` buffers; after the optimizer step
    runs, the buffers are cleared back to zeros via `_pipeline_clear_grad`.
    """
    def __init__(self, network, optimizer, sens=None):
        super(_TrainGradAccuStepCell, self).__init__(network, optimizer, sens)
        # Zero-initialized clones of the trainable weights, used as accumulation buffers.
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.hyper_map = ops.HyperMap()
        # When the parallel optimizer (optimizer sharding) is on, the optimizer
        # consumes the raw grads instead of the accumulation buffers.
        self.opt_shard = _get_enable_parallel_optimizer()
        self._get_attr_from_cell(network)
        self.enable_tft = False

    def construct(self, *inputs):
        if not self.sense_flag:
            return self._no_sens_impl(*inputs)
        loss = self.network(*inputs)
        # Build the sensitivity tensor matching the loss dtype/shape.
        sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        # Order: read the accumulation buffers only after grads are produced.
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        # Clear the accumulation buffers after the optimizer step has run.
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss

    def _no_sens_impl(self, *inputs):
        """construct implementation when the 'sens' parameter is not passed in."""
        loss = self.network(*inputs)
        grads = self.grad_no_sens(self.network, self.weights)(*inputs)
        # Order: read the accumulation buffers only after grads are produced.
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        # Clear the accumulation buffers after the optimizer step has run.
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
class AllreduceGraph(Cell):
    """
    An allreduce graph to broadcast parameters.

    Args:
        inputs (list): The parameters to be summed across the group and written back in place.
        group_name (str): Name of the communication group used by the AllReduce ops.
    """
    def __init__(self, inputs, group_name):
        super(AllreduceGraph, self).__init__()
        self.input_num = len(inputs)
        self.inputs = inputs
        self.allreduces = []
        self.assigns = []
        # One AllReduce/Assign pair per input so every tensor is summed and written back.
        for _ in range(self.input_num):
            self.allreduces.append(ops.AllReduce(op="sum", group=group_name))
            self.assigns.append(ops.Assign())

    def construct(self):
        for i in range(self.input_num):
            res = self.allreduces[i](self.inputs[i])
            # Write the reduced value back into the original parameter.
            self.assigns[i](self.inputs[i], res)
        return self.inputs
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
class VirtualDatasetCellTriple(Cell):
    """
    Deprecated wrapper that converts a data-parallel layout to a model-parallel layout
    via a virtual dataset.

    VirtualDatasetCellTriple is a virtual Primitive that never appears in the final
    executing graph. Inputs and outputs of VirtualDatasetCellTriple are distributed
    in the data-parallel pattern; tensor-redistribution Primitives are inserted
    dynamically while the graph is compiled.

    Note:
        Only used in semi auto parallel and auto parallel mode. There are three
        inputs, as contrary to two inputs in _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> net = nn.VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        # Emit the deprecation notice once, at wrap time.
        logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
        self._backbone = backbone
        self._get_attr_from_cell(backbone)

    def construct(self, a, b, c):
        # Pass the three inputs straight through to the wrapped network.
        return self._backbone(a, b, c)
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
class WithEvalCell(Cell):
    r"""
    Evaluation wrapper that couples a forward network with its loss function.

    For each (data, label) pair it returns the loss, the raw network outputs and
    the label — the triple consumed by metric objects.

    Args:
        network (Cell): The forward network.
        loss_fn (Cell): The loss function.
        add_cast_fp32 (bool): Whether to cast outputs and label to float32
            before computing the loss. Default: ``False`` .

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple(Tensor), containing a scalar loss Tensor, a network output Tensor of
        shape :math:`(N, \ldots)` and a label Tensor of shape :math:`(N, \ldots)`.

    Raises:
        TypeError: If `add_cast_fp32` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn
        >>> # Define a forward network without loss function, taking LeNet5 as an example.
        >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
        self._get_attr_from_cell(network)

    def construct(self, data, label):
        """Run the forward pass and compute the evaluation loss."""
        predictions = self._network(data)
        if self.add_cast_fp32:
            # Align both operands to float32 so the loss dtype is stable.
            label = F.mixed_precision_cast(mstype.float32, label)
            predictions = F.cast(predictions, mstype.float32)
        eval_loss = self._loss_fn(predictions, label)
        return eval_loss, predictions, label
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
class ParameterUpdate(Cell):
    """
    Cell that manually overwrites a parameter with an input tensor.

    With this Cell, one can update `param` in place with the value of the input
    `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Inputs:
        - **x** (Tensor) - A tensor whose shape and type are the same with `param`.

    Outputs:
        Tensor, the updated value.

    Raises:
        KeyError: If parameter with the specified name does not exist.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import nn, Tensor
        >>> network = nn.Dense(3, 4)
        >>> param = network.parameters_dict()['weight']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
        >>> output = update(weight)
        >>> print(output)
        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 6.  7.  8.]
         [ 9. 10. 11.]]
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if isinstance(param, Parameter):
            self._param = param
        else:
            raise TypeError("For 'ParameterUpdate', 'param' must be 'Parameter', but got {}.".format(type(param)))

    def construct(self, x):
        # Write the new value into the parameter, then hand the value back.
        F.assign(self._param, x)
        return x
|
|
939
|
+
|
|
940
|
+
|
|
941
|
+
class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to other devices.

    Args:
        params (list): The parameters of Net.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        # Local imports keep module load free of communication-stack side effects.
        from mindspore.communication.management import get_group_size, create_group
        from mindspore import context
        self.map_ = C.Map()
        self.params = tuple(params)
        # Ascend graph mode needs an explicitly created world group for Broadcast.
        if context.get_context("device_target") == "Ascend" and context.get_context("mode") != context.PYNATIVE_MODE:
            rank_list = [id for id in range(0, get_group_size())]  # NOTE(review): loop var shadows builtin `id`
            create_group("BroadcastWorldGroup", rank_list)
            self.broadcast = P.Broadcast(0, group="BroadcastWorldGroup")
        else:
            self.broadcast = P.Broadcast(0)
        self.add_flags(skip_auto_parallel_compile=True)

    def construct(self):
        # Remember original dtypes, cast everything to float32 for the broadcast,
        # then cast each tensor back to its original dtype.
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params
|