mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
mindspore/train/amp.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
1
|
+
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Auto mixed precision."""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
import inspect
|
|
18
|
+
import types
|
|
19
|
+
from typing import Any
|
|
20
|
+
import functools
|
|
21
|
+
import collections
|
|
22
|
+
|
|
23
|
+
import mindspore as ms
|
|
24
|
+
from mindspore import nn
|
|
25
|
+
from mindspore import _checkparam as validator
|
|
26
|
+
from mindspore.common import dtype as mstype
|
|
27
|
+
from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
|
|
28
|
+
from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
|
|
29
|
+
from mindspore.ops import functional as F
|
|
30
|
+
from mindspore.parallel._utils import _get_pipeline_stages
|
|
31
|
+
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
|
32
|
+
from mindspore import boost, context
|
|
33
|
+
from mindspore.ops import operations as P
|
|
34
|
+
from mindspore.ops import Primitive
|
|
35
|
+
from mindspore.ops import auto_generate as gen
|
|
36
|
+
from mindspore import log as logger
|
|
37
|
+
from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel
|
|
38
|
+
|
|
39
|
+
# Cells and primitives converted to lower precision (float16/bfloat16) when
# amp_level is "O1" (see auto_mixed_precision below).
AMP_WHITE_LIST = [
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose,
    nn.Dense,
    nn.LSTMCell,
    nn.RNNCell,
    nn.GRUCell,
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    P.Conv2DBackpropInput,
    P.MatMul,
    P.BatchMatMul,
    P.PReLU,
    P.ReLU,
    P.Ger,
]

# Cells kept in full precision (float32) when amp_level is "O2"; everything
# else is converted to lower precision.
AMP_BLACK_LIST = [
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
]

# Operators converted to lower precision when amp_level is "auto".
AMP_AUTO_WHITE_LIST = [
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    gen.Convolution,
    P.MatMul,
    gen.MatMulExt,
    P.BatchMatMul,
    gen.BatchMatMulExt,
    gen.PReLU,
    P.Einsum,
    gen.Dense,
    gen.Addmm,
]

# Operators converted to full precision when amp_level is "auto".
AMP_AUTO_BLACK_LIST = [
    gen.Pow,
    gen.ACos,
    gen.Asin,
    gen.Cosh,
    P.Erfinv,
    P.Exp,
    P.Expm1,
    P.Log,
    P.Log1p,
    P.Reciprocal,
    P.Rsqrt,
    P.Sinh,
    P.Tan,
    P.Softplus,
    gen.SoftplusExt,
    P.LayerNorm,
    gen.LayerNormExt,
    P.BatchNorm,
    gen.GroupNorm,
    P.KLDivLoss,
    P.SmoothL1Loss,
    P.MultilabelMarginLoss,
    P.SoftMarginLoss,
    P.TripletMarginLoss,
    P.MultiMarginLoss,
    P.BCEWithLogitsLoss,
    P.Pdist,
    P.Cdist,
    P.Renorm,
]

# Indicates which inputs of primitives need to be converted
AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})

# Primitives in inner amp black list will not be converted in O2/O3
_INNER_AMP_BLACK_LIST = []

# Feature switch: when True, amp is implemented via the rewrite passes below.
MS_AMP_BY_REWRITE = False
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def amp_cast(value, dtype):
    """This function is used to insert cast operators for tensors during auto mixed precision."""
    # Only floating-point tensors are cast; every other value passes through unchanged.
    is_float_tensor = isinstance(value, ms.Tensor) and value.dtype in mstype.float_type
    if not is_float_tensor:
        return value
    return P.Cast()(value, dtype)


# Alias referenced by the rewrite passes below when creating/detecting cast nodes.
_amp_cast_op = amp_cast
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class _OutputTo16(nn.Cell):
    """Wrap cell for amp. Cast network output back to float16 (or bfloat16)."""
    def __init__(self, backbone, dtype=mstype.float16):
        super(_OutputTo16, self).__init__(auto_prefix=False)
        # The wrapped cell whose output will be down-cast.
        self._backbone = backbone
        # Target lower-precision dtype for the output cast.
        self.dtype = dtype
        # Copy amp-related attributes from the backbone onto this wrapper.
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        # Run the backbone, then cast its output to the configured lower-precision dtype.
        return F.cast(self._backbone(*args, **kwargs), self.dtype)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class _OutputTo32(nn.Cell):
    """Wrap loss for amp. Cast network output back to float32."""
    def __init__(self, backbone):
        super(_OutputTo32, self).__init__(auto_prefix=False)
        # The wrapped cell whose output will be up-cast to float32.
        self._backbone = backbone
        # Copy amp-related attributes from the backbone onto this wrapper.
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        out = self._backbone(*args, **kwargs)
        # mixed_precision_cast handles nested structures, unlike a plain F.cast.
        return F.mixed_precision_cast(mstype.float32, out)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether current node is an operator that needs to be casted. The following conditions need
    to be satisfied:
    1) Type of node is CallPrimitive and type of instance is Primitive
    2) Type of instance is not P.Cast
    3) force_cast is True, which means one of upper layer cells is under casting
    4) white_list exists and type of node is in white_list
    5) black_list exists and type of node is not in black_list
    """
    if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
        return False
    if not inspect.isclass(node.get_instance_type()):
        return False
    if not issubclass(node.get_instance_type(), Primitive):
        return False
    # Never wrap an existing cast in another cast pair.
    if issubclass(node.get_instance_type(), P.Cast):
        return False
    # Primitives on the inner black list are never converted.
    if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
        return False
    if force_cast:
        return True
    if white_list is not None and node.get_instance_type() in white_list:
        return True
    if black_list is not None and node.get_instance_type() not in black_list:
        return True
    return False
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
    """Check whether cell precision is set by user (any of the fp32/fp16/bf16 flags is truthy)."""
    precision_flags = ("fp32", "fp16", "bf16")
    return any(hasattr(cell_inst, flag) and getattr(cell_inst, flag) for flag in precision_flags)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether current node is type of tree whose network needs to be casted. The following
    conditions need to be satisfied:
    1) Type of node is Tree and type of instance is Cell
    2) Cell precision (fp32/fp16/bf16 flag, e.g. via Cell.to_float) is not set by user
    3) force_cast is True, which means one of upper layer networks is under casting
    4) white_list exists and type of node is in white_list
    5) black_list exists and type of node is not in black_list
    """
    if node.get_node_type() != ms.rewrite.NodeType.Tree:
        return False
    if not inspect.isclass(node.get_instance_type()):
        return False
    if not issubclass(node.get_instance_type(), nn.Cell):
        return False
    # Cells on the inner black list are never converted.
    if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
        return False
    # User-specified precision always wins over automatic conversion.
    if _precision_set_by_user(node.get_instance()):
        return False
    if force_cast:
        return True
    if white_list is not None and node.get_instance_type() in white_list:
        return True
    if black_list is not None and node.get_instance_type() not in black_list:
        return True
    return False
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _insert_cast_for_operator(node, dtype):
    """
    Insert a cast pair around `node`: cast each named input to fp16/bf16 before the
    node, and cast each named output back to float32 after it.
    """
    dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    stree = node.get_symbol_tree()
    # insert cast fp16/bf16 for inputs of node
    for idx, arg in enumerate(node.get_args()):
        # Only named values can be re-routed through a cast node; literals are skipped.
        if arg.type != ms.rewrite.ValueType.NamingValue:
            continue
        incast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, "mindspore"])
        arg_providers = node.get_arg_providers()
        if not arg_providers or idx not in arg_providers or \
                len(arg_providers[idx][0].get_target_users(arg_providers[idx][1])) > 1:
            # create new target names when argument is used by other node
            # NOTE(review): this branch yields a list of plain strings while the other
            # yields ScopedValue objects — presumably create_call_function accepts both.
            incast_targets = [stree.unique_name(f"{arg.value}_var")]
        else:
            # Safe to reuse the original name: this node is the argument's only consumer.
            incast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
        incast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=incast_targets, args=incast_args)
        stree.insert(stree.before(node), incast_node)
        # Re-point the node's argument at the freshly inserted cast output.
        node.set_arg_by_node(idx, incast_node)
    # insert cast fp32 for outputs of node
    for _, target in enumerate(node.get_targets()):
        if target.type != ms.rewrite.ValueType.NamingValue:
            continue
        outcast_args = ms.rewrite.ScopedValue.create_name_values([target.value, "float32"],
                                                                 [target.scope, "mindspore"])
        # Output cast writes back to the same target name, so downstream users are unchanged.
        outcast_targets = ms.rewrite.ScopedValue.create_name_values([target.value], [target.scope])
        outcast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=outcast_targets, args=outcast_args)
        stree.insert(stree.after(node), outcast_node)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
    """
    Insert cast pairs for all qualifying operators in `stree`, recursing into
    sub-trees (sub-cells). `force_cast` propagates the "an enclosing cell is
    being cast" state down the recursion.
    """
    # get all nodes of stree exclude nodes in subtree.
    all_nodes = stree.all_nodes(False)
    for node in all_nodes:
        # Nodes without targets produce no value to cast.
        if not node.get_targets():
            continue
        if _operator_need_cast(node, force_cast, white_list, black_list):
            _insert_cast_for_operator(node, dtype)
        elif node.get_node_type() == ms.rewrite.NodeType.Tree:
            # Decide once whether the whole sub-network must be cast, then recurse.
            force_cast_ = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
            # Skip sub-cells whose precision was explicitly set by the user.
            if not _precision_set_by_user(node.get_instance()):
                subtree = node.get_sub_tree()
                _insert_cast_for_operators(subtree, dtype, force_cast_, white_list=white_list, black_list=black_list)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _need_removed_cast_pair(node, dtype):
    """
    Check whether the cast pairs should be removed: `node` must be a cast-to-float32
    whose every user is a cast back to fp16/bf16, with no control flow in between.
    Such back-to-back pairs are redundant and can be erased.
    """
    dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "float32"], ["mindspore", "mindspore"])
    cast_dtype_f16 = cast_dtypes[0]
    cast_dtype_f32 = cast_dtypes[1]
    # current node should be cast fp32
    if node.get_instance_type() != _amp_cast_op:
        return False
    # args are (value, dtype); compare the dtype argument.
    node_cast_type = node.get_args()[1]
    if node_cast_type != cast_dtype_f32:
        return False
    # all user nodes should be cast fp16/bf16
    if not node.get_users():
        return False
    # NOTE(review): assumes node and all its users live in the same node manager,
    # so index() below can locate both — verify against rewrite internals.
    all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
    for user in node.get_users():
        # If ControlFlow node(e.g. if, for, while) exists between current node and user node,
        # cast pair should not be removed.
        middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
        if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
            return False
        if user.get_instance_type() != _amp_cast_op:
            return False
        user_cast_type = user.get_args()[1]
        if user_cast_type != cast_dtype_f16:
            return False
        # cast pair detected, check next user
        continue
    return True
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _remove_duplicated_cast(stree, dtype):
    """
    Remove redundant cast pairs (cast-to-fp32 immediately followed only by
    casts back to fp16/bf16), re-wiring users to the original value.
    """
    all_nodes = list(stree.nodes(all_nodes=True))
    for node in all_nodes:
        if _need_removed_cast_pair(node, dtype):
            # All users of a removable fp32 cast are fp16/bf16 casts (checked above).
            incast_nodes = node.get_users()
            # remove cast fp16/bf16 nodes
            for incast_node in incast_nodes:
                # get_target_users() return {target0: [(user0, arg_idx), ...], ...}
                target_users = list(incast_node.get_target_users().values())
                if not target_users or not target_users[0]:
                    continue
                # Bypass the cast: feed its own input straight to each consumer.
                for user_node, arg_idx in target_users[0]:
                    user_node.set_arg(arg_idx, incast_node.get_args()[0])
                stree.erase(incast_node)
            # remove the cast fp32 node
            stree.erase(node)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
    """
    Implement auto mixed precision by rewrite.

    Exactly one of `white_list`/`black_list` must be provided; it drives which
    operators get cast pairs inserted. Returns the rewritten network.
    """
    if (white_list is None and black_list is None) or (white_list is not None and black_list is not None):
        raise ValueError("For _auto_mixed_precision_rewrite, one of white_list and black_list must be provided.")
    # enable rewrite configs for amp
    ms.rewrite.common.namespace._ms_cells_to_subtree = True
    ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
    try:
        # insert casts by rewrite
        stree = ms.rewrite.SymbolTree.create(network)
        _insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
        _remove_duplicated_cast(stree, dtype)
        new_net = stree.get_network()
    finally:
        # disable rewrite configs even if the rewrite raises, so the global
        # rewrite state does not leak into subsequent unrelated rewrites
        ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
        ms.rewrite.common.namespace._ms_cells_to_subtree = False
        ms.rewrite.common.config.clear_caches()
    return new_net
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _auto_black_list(network, black_list, dtype):
    """
    Process the black list of network: cast the whole network to `dtype`, then
    restore float32 for sub-cells whose type is in `black_list`, wrapping them
    in _OutputTo16 so their outputs rejoin the low-precision dataflow.
    Mutates `network` in place and also returns it.
    """
    network.to_float(dtype)
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        if isinstance(subcell, tuple(black_list)):
            # Black-listed cell: compute in float32, cast its output back down.
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
            change = True
        else:
            _auto_black_list(subcell, black_list, dtype)
    # SequentialCell caches its children; refresh after swapping any of them.
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())
    return network
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
class amp_decorator:
    """
    Auto mixed precision context manager: pushes an amp strategy on enter and
    pops it on exit.
    Type of lists: List[Tuple[str, List[int]]]
    """
    def __init__(self, amp_level, amp_dtype, white_list, black_list):
        self.amp_level = amp_level
        self.amp_dtype = amp_dtype
        self.white_list = white_list
        self.black_list = black_list

    def __enter__(self):
        push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        # Always pop, even on exception; returning None propagates any exception.
        pop_amp_strategy()
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
    """
    Set auto mixed precision context decorator for object.

    For a function or method, return a wrapper that runs it inside an
    amp_decorator context. For a Cell, rebind its `construct` to the wrapped
    underlying function. Raises TypeError for any other object.
    Type of lists: List[Tuple[str, List[int]]]
    """
    if inspect.isfunction(obj) or inspect.ismethod(obj):
        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            with amp_decorator(amp_level, amp_dtype, white_list, black_list):
                return obj(*args, **kwargs)
        return wrapper
    if isinstance(obj, nn.Cell):
        # Wrap the plain function behind the bound method, then rebind it to obj.
        obj.construct = types.MethodType(
            _set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
        return obj
    # Fixed typo in the error message: "bot got" -> "but got".
    raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, but got {type(obj)}.")
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
396
|
+
"""
|
|
397
|
+
Returns a network processed with auto mixed precision.
|
|
398
|
+
|
|
399
|
+
This interface will automatically perform mixed-precision processing on the input network, and the cells
|
|
400
|
+
and operators in the processed network will add precision conversion operations to calculate with lower
|
|
401
|
+
precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
|
|
402
|
+
converted to lower precision float, and calculation results are converted back to full precision float,
|
|
403
|
+
i.e. ``mstype.float32`` .
|
|
404
|
+
|
|
405
|
+
The `amp_level` and its corresponding lists determine which cells and operators are converted.
|
|
406
|
+
|
|
407
|
+
When `amp_level` is set to ``O0``, no cells and operators are converted.
|
|
408
|
+
|
|
409
|
+
When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
|
|
410
|
+
operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.
|
|
411
|
+
|
|
412
|
+
When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
|
|
413
|
+
list will be converted to low precision. For details on blacklist, refer to :func:`mindspore.amp.get_black_list`.
|
|
414
|
+
|
|
415
|
+
When `amp_level` is set to ``O3``, all cells will be converted to low precision.
|
|
416
|
+
|
|
417
|
+
When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
|
|
418
|
+
operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
|
|
419
|
+
`promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
|
|
420
|
+
not listed will run in the type defined by their inputs.
|
|
421
|
+
|
|
422
|
+
Operators in `auto_whitelist` are:
|
|
423
|
+
|
|
424
|
+
``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``, ``MatMulExt``,
|
|
425
|
+
``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``
|
|
426
|
+
|
|
427
|
+
Operators in `auto_blacklist` are:
|
|
428
|
+
|
|
429
|
+
``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
|
|
430
|
+
``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
|
|
431
|
+
``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
|
|
432
|
+
``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
|
|
433
|
+
``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
|
|
434
|
+
``Norm``
|
|
435
|
+
|
|
436
|
+
Operators in `promote_list` are:
|
|
437
|
+
|
|
438
|
+
``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
|
|
439
|
+
``BiasAdd``
|
|
440
|
+
|
|
441
|
+
For details on automatic mixed precision, refer to
|
|
442
|
+
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
|
|
443
|
+
|
|
444
|
+
Note:
|
|
445
|
+
- Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
|
|
446
|
+
can result in a larger network hierarchy and slower performance.
|
|
447
|
+
- If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
|
|
448
|
+
mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
|
|
449
|
+
need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
|
|
450
|
+
- When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case, you
|
|
451
|
+
may need to manually convert the type to avoid type inconsistency errors of the loss function.
|
|
452
|
+
- When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the accuracy
|
|
453
|
+
specified by `to_float` takes effect first.
|
|
454
|
+
|
|
455
|
+
.. warning::
|
|
456
|
+
``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.
|
|
457
|
+
|
|
458
|
+
Args:
|
|
459
|
+
network (Union[Cell, function]): Definition of the network. Function type is supported only when `amp_level`
|
|
460
|
+
is set to ``auto`` .
|
|
461
|
+
amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
|
|
462
|
+
|
|
463
|
+
- "O0": Do not change.
|
|
464
|
+
- "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
|
|
465
|
+
precision operations for the rest.
|
|
466
|
+
- "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
|
|
467
|
+
to lower precision operations.
|
|
468
|
+
- "O3": Cast network to lower precision.
|
|
469
|
+
- "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators in
|
|
470
|
+
`auto_blacklist` will be converted to full precision, operators in `promote_list` will be converted
|
|
471
|
+
to the higher accuracy float type of the operator inputs, and operators not listed will run in the
|
|
472
|
+
type defined by their inputs.
|
|
473
|
+
|
|
474
|
+
dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
|
|
475
|
+
default: ``mstype.float16`` .
|
|
476
|
+
|
|
477
|
+
Raises:
|
|
478
|
+
TypeError: If `network` is not a Cell or a function.
|
|
479
|
+
ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
|
|
480
|
+
ValueError: If `amp_level` is not within the supported range.
|
|
481
|
+
|
|
482
|
+
Examples:
|
|
483
|
+
>>> from mindspore import amp
|
|
484
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
485
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
486
|
+
>>> network = LeNet5()
|
|
487
|
+
>>> amp_level = "O1"
|
|
488
|
+
>>> net = amp.auto_mixed_precision(network, amp_level)
|
|
489
|
+
"""
|
|
490
|
+
if not isinstance(network, nn.Cell):
|
|
491
|
+
if amp_level == "auto":
|
|
492
|
+
if not inspect.isfunction(network) and not inspect.ismethod(network):
|
|
493
|
+
raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
|
|
494
|
+
# function is supported for amp_level 'auto'
|
|
495
|
+
else:
|
|
496
|
+
raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")
|
|
497
|
+
|
|
498
|
+
if dtype not in (mstype.float16, mstype.bfloat16):
|
|
499
|
+
raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
|
|
500
|
+
|
|
501
|
+
if amp_level == "O0":
|
|
502
|
+
return network
|
|
503
|
+
|
|
504
|
+
# Return network if the same amp level has already been configured
|
|
505
|
+
if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
|
|
506
|
+
logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
|
|
507
|
+
f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
|
|
508
|
+
f"degradation.")
|
|
509
|
+
|
|
510
|
+
if amp_level == "O1":
|
|
511
|
+
network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
|
|
512
|
+
elif amp_level == "O2":
|
|
513
|
+
if MS_AMP_BY_REWRITE:
|
|
514
|
+
network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
|
|
515
|
+
else:
|
|
516
|
+
network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
|
|
517
|
+
network = _OutputTo32(network)
|
|
518
|
+
elif amp_level == "O3":
|
|
519
|
+
if MS_AMP_BY_REWRITE:
|
|
520
|
+
network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
|
|
521
|
+
else:
|
|
522
|
+
network.to_float(dtype)
|
|
523
|
+
network = _OutputTo32(network)
|
|
524
|
+
elif amp_level == "auto":
|
|
525
|
+
white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
|
|
526
|
+
black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
|
|
527
|
+
# set amp_strategy attribute for the object
|
|
528
|
+
amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
|
|
529
|
+
setattr(network, "amp_strategy", amp_strategy)
|
|
530
|
+
# set amp_strategy context decorator for the object
|
|
531
|
+
network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
|
|
532
|
+
else:
|
|
533
|
+
raise ValueError(f"The amp level {amp_level} is not supported")
|
|
534
|
+
|
|
535
|
+
setattr(network, "_amp_level", amp_level)
|
|
536
|
+
|
|
537
|
+
return network
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def _do_keep_batchnorm_fp32(network):
    """Recursively keep AMP_BLACK_LIST cells (batchnorm-like layers) in float32.

    Matching sub-cells are forced to float32 and wrapped so their outputs are
    cast back to float16; every other sub-cell is processed recursively.
    """
    replaced = False
    named = network.name_cells()
    for cell_name in named:
        child = named[cell_name]
        if child == network:
            # Skip the container itself to avoid infinite recursion.
            continue
        if isinstance(child, nn.Cell) and isinstance(child, tuple(AMP_BLACK_LIST)):
            network._cells[cell_name] = _OutputTo16(child.to_float(mstype.float32))
            replaced = True
        else:
            _do_keep_batchnorm_fp32(child)
    # SequentialCell keeps a separate cached list of children; refresh it after
    # any in-place replacement so iteration sees the wrapped cells.
    if replaced and isinstance(network, nn.SequentialCell):
        network.cell_list = list(network.cells())
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
# Per-amp-level defaults consumed by build_train_network(); each entry may be
# overridden by the matching keyword argument (validated by _check_kwargs()).
_config_level = {
    "O0": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O1": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O2": {
        "keep_batchnorm_fp32": True,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": DynamicLossScaleManager()},
    "O3": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": None},
    "auto": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None}}
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _check_kwargs(key_words):
    """Validate the keyword arguments accepted by build_train_network().

    Raises:
        ValueError: An unknown keyword is present.
        TypeError: A known keyword has a value of the wrong type (via validator).
    """
    supported = ('cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager')
    for arg in key_words:
        if arg not in supported:
            raise ValueError(f"Unsupported arg '{arg}'")

    if 'cast_model_type' in key_words:
        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
                                  [mstype.float16, mstype.float32], None)
    if 'keep_batchnorm_fp32' in key_words:
        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
    # A falsy manager (None) is deliberately accepted without type checking.
    manager = key_words.get('loss_scale_manager')
    if manager:
        validator.check_value_type('loss_scale_manager', manager,
                                   [LossScaleManager, boost.GroupLossScaleManager])
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def _check_level(level, boost_level):
    """Validate `level` and `boost_level`, and derive whether boost mode is on.

    Args:
        level (str): Mixed precision level; one of ['O0', 'O1', 'O2', 'O3', 'auto'].
        boost_level (str): Boost mode level; one of ['O0', 'O1', 'O2'].

    Returns:
        tuple(str, bool): The validated `level`, and True if boost training
        ('O1'/'O2') is enabled.

    Raises:
        TypeError: If `level` is not a string.
        ValueError: If `level` or `boost_level` is out of range (via validator).
    """
    if not isinstance(level, str):
        # Fix: a space was missing between the two concatenated message parts,
        # which produced "...'auto'],but got type ...".
        raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], "
                        f"but got type {type(level)}.")
    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)

    # Boost acceleration libraries only take effect for boost levels 'O1'/'O2'.
    enable_boost = boost_level in ("O1", "O2")

    return level, enable_boost
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def _add_loss_network(network, loss_fn, cast_model_type):
    """Combine `network` and `loss_fn` into a single with-loss cell.

    When the backbone runs in reduced precision (cast to float16/bfloat16, or an
    amp level of "O2"/"O3"/"auto" already applied), the backbone output and the
    label are cast back to float32 before the loss is evaluated, so the loss is
    always computed in full precision.

    Args:
        network (Cell): The backbone network.
        loss_fn (Cell): The loss function.
        cast_model_type (mindspore.dtype): The dtype the model was cast to.

    Returns:
        Cell, the network wrapped together with the loss function.
    """

    class WithLossCell(nn.Cell):
        """Wrap loss for amp. Cast network output back to float32."""
        def __init__(self, backbone, loss_fn):
            # auto_prefix=False: do not prepend this wrapper's name to the
            # backbone's parameter names.
            super(WithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn
            self._get_attr_from_cell(backbone)

        def construct(self, data, label):
            out = self._backbone(data)
            # Evaluate the loss in float32 regardless of backbone precision.
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

    validator.check_value_type('loss_fn', loss_fn, nn.Cell)
    # Use the float32-restoring wrapper only when reduced precision is in play;
    # otherwise the plain nn.WithLossCell suffices.
    if cast_model_type in (mstype.float16, mstype.bfloat16) or \
            (hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
        network = WithLossCell(network, loss_fn)
    else:
        network = nn.WithLossCell(network, loss_fn)
    return network
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def _is_grad_accumulation(mcell):
|
|
639
|
+
if mcell.cls_name == "GradAccumulationCell":
|
|
640
|
+
return True
|
|
641
|
+
for cell in mcell.cells():
|
|
642
|
+
if _is_grad_accumulation(cell):
|
|
643
|
+
return True
|
|
644
|
+
return False
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def _auto_mixed_precision_process(network, config, level):
    """Auto mixed precision process.

    Applies the amp `config` (per-level defaults merged with user kwargs) to
    `network`, using either the rewrite-based path or the legacy to_float path.

    Args:
        network (Cell): The network to convert.
        config (dict): Holds 'cast_model_type', 'keep_batchnorm_fp32' and
            'loss_scale_manager' entries.
        level (str): The amp level originally requested by the caller.

    Returns:
        Cell, the converted network.
    """
    if MS_AMP_BY_REWRITE:
        # Fold the kwargs-overridable config back into a single rewrite level.
        if config["cast_model_type"] == mstype.float16 or level == "O2":
            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
            # cast_model_type set by kwargs
            level = "O0"
        network = auto_mixed_precision(network, level)
    else:
        if config["cast_model_type"] == mstype.float16:
            network.to_float(mstype.float16)

            if config["keep_batchnorm_fp32"]:
                _do_keep_batchnorm_fp32(network)
        elif not config["keep_batchnorm_fp32"] and level == "O2":
            network.to_float(mstype.float16)
        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
            # cast_model_type overridden to float32 via kwargs: skip casting.
            pass
        else:
            network = auto_mixed_precision(network, level)
    return network
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.

    Note:
        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
          to perform the precision conversion again. If `build_train_network` is used to train a converted network,
          `level` needs to be configured to ``O0`` to avoid the duplicated accuracy conversion.

    Args:
        network (Cell): Definition of the network.
        optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
        loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
            Default: ``None`` .
        level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .

            For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.

            Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
            setting may be overwritten by settings in `kwargs`.

        boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
            training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .

            - 'O0': Do not change.
            - 'O1': Enable the boost mode, the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
            - 'O2': Enable the boost mode, the performance is improved by about 30%, and
              the accuracy is reduced by less than 3%.

            If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.

        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
            network will be casted to `cast_model_type` ( `mstype.float16` or `mstype.float32` ), but not to be casted
            to the type determined by `level` setting.
        keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32` when the network is set to cast to `float16` . If
            set, the `level` setting will take no effect on this property.
        loss_scale_manager (Union[None, LossScaleManager]): If not None, must be subclass of
            :class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
            take no effect on this property.

    Raises:
        ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
            :class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> amp_level="O3"
        >>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)
    """
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
                                                        nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))

    level, enable_boost = _check_level(level, boost_level)

    _check_kwargs(kwargs)
    # Per-level defaults, possibly overridden by the user's kwargs.
    config = dict(_config_level.get(level), **kwargs)

    network = _auto_mixed_precision_process(network, config, level)

    if loss_fn:
        network = _add_loss_network(network, loss_fn, config["cast_model_type"])

    loss_scale = None
    if config["loss_scale_manager"] is not None:
        loss_scale_manager = config["loss_scale_manager"]
        loss_scale = loss_scale_manager.get_loss_scale()
        update_cell = loss_scale_manager.get_update_cell()
        # A non-None update cell means dynamic loss scaling: use the
        # loss-scale-aware train-step wrappers and return early.
        if update_cell is not None:
            # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
            if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                raise ValueError("Only `loss_scale_manager=None` or "
                                 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                                 "are supported on device `CPU`. ")
            if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
                network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                          scale_sense=update_cell).set_train()
            elif enable_boost:
                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
                                                                   scale_sense=update_cell).set_train()
            else:
                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            return network
    # Fixed (or no) loss scale: use the plain train-step wrappers.
    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
        network = _TrainGradAccuStepCell(network, optimizer).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
    else:
        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
    return network
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def get_white_list():
    """
    Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.

    The current built-in whitelist contents are:

    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]

    Returns:
        list, A copy of internal white list.

    Examples:
        >>> from mindspore import amp
        >>> white_list = amp.get_white_list()
        >>> print(white_list)
        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
        <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
        <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
        <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
        <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
        <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
        <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
        <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
        <class 'mindspore.ops.operations.math_ops.Ger'>]
    """
    # Hand out a shallow copy so callers can extend it without mutating the
    # internal list.
    return list(AMP_WHITE_LIST)
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def get_black_list():
    """
    Provide a copy of internal black list used by auto mixed precision with `amp_level` set to ``O2``.

    The current built-in blacklist contents are:

    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
    :class:`mindspore.nn.LayerNorm`]

    Returns:
        list, A copy of internal black list.

    Examples:
        >>> from mindspore import amp
        >>> black_list = amp.get_black_list()
        >>> print(black_list)
        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
        <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
    """
    # Hand out a shallow copy so callers can extend it without mutating the
    # internal list.
    return list(AMP_BLACK_LIST)
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
    """
    When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
    When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
    Only one of `white_list` and `black_list` should be provided.

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train the network which is converted by
          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
          need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
        - Primitives for blacklist is not support yet.

    Args:
        network (Cell): Definition of the network.
        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: ``None`` , means
            white list is not used.
        black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
            black list is not used.
        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
            default: ``mstype.float16`` .

    Returns:
        network (Cell), A network supporting mixed precision.

    Raises:
        TypeError: The network type is not Cell.
        ValueError: Neither `white_list` nor `black_list` is provided.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
        ValueError: Both `white_list` and `black_list` are provided.

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> custom_white_list = amp.get_white_list()
        >>> custom_white_list.append(nn.Flatten)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
    """
    # Validate arguments in a fixed order so callers get stable error types.
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    if white_list is None and black_list is None:
        raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")

    if white_list is not None and black_list is not None:
        raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                         "at the same time, please provide one or the other.")

    if dtype not in (mstype.float16, mstype.bfloat16):
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    if white_list is not None:
        _list_check(white_list, "white_list")
        network = _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)
    else:
        _list_check(black_list, "black_list")
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)
        else:
            # Legacy path: cast everything outside the black list and restore
            # a float32 output at the network boundary.
            network = _auto_black_list(network, black_list, dtype)
            network = _OutputTo32(network)
    return network
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
def _list_check(custom_list: list, list_name: str):
    """Validate a user-supplied white/black list for custom mixed precision.

    Raises:
        TypeError: The type of custom_list is not list.
        TypeError: The element in custom_list is not a class.
        TypeError: The subclass of element in custom_list is not one of ['Cell', 'Primitive'].
    """
    if not isinstance(custom_list, list):
        raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")

    checking_white = list_name == "white_list"
    checking_black = list_name == "black_list"
    for elem in custom_list:
        if not isinstance(elem, type):
            raise TypeError(f"The element in {list_name} should be a class, but got {elem}")
        if checking_white and not (issubclass(elem, nn.Cell) or issubclass(elem, Primitive)):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
                            f"but got {elem}")
        if checking_black and not issubclass(elem, nn.Cell):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {elem}")

    if checking_black:
        # Warn when a user-provided black list drops entries of the built-in one.
        for builtin in AMP_BLACK_LIST:
            if builtin not in custom_list:
                logger.warning(f"{builtin} is removed from internal black list.")
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None): # pylint: disable=unused-variable
    """Configure auto mixed precision.

    Keyword Args:
        enable_rewrite (bool, optional): When given, toggles the rewrite-based amp path.
        cast_op (FunctionType, optional): When given, replaces the amp cast operation.
    """
    global MS_AMP_BY_REWRITE, _amp_cast_op

    if enable_rewrite is not None:
        MS_AMP_BY_REWRITE = enable_rewrite
    if cast_op is not None:
        _amp_cast_op = cast_op
|