mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
- mindspore/__init__.py +9 -4
- mindspore/_akg/akg/composite/build_module.py +11 -0
- mindspore/_akg/akg/config/repository_cuda.json +11 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +53 -58
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/config/super_bar_config.json +512 -0
- mindspore/context.py +291 -56
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/include/mindapi/base/type_id.h +42 -3
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libicudata.so.69 +0 -0
- mindspore/lib/libicui18n.so.69 +0 -0
- mindspore/lib/libicuuc.so.69 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/scipy/linalg.py +13 -117
- mindspore/scipy/ops.py +5 -71
- mindspore/scipy/ops_grad.py +1 -25
- mindspore/scipy/ops_wrapper.py +1 -1
- mindspore/scipy/optimize/_bfgs.py +1 -1
- mindspore/scipy/optimize/_lagrange.py +200 -0
- mindspore/scipy/optimize/line_search.py +3 -2
- mindspore/scipy/optimize/minimize.py +43 -6
- mindspore/scipy/sparse/__init__.py +2 -2
- mindspore/scipy/sparse/linalg.py +5 -465
- mindspore/scipy/utils.py +2 -1
- mindspore/scipy/utils_const.py +7 -1
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/_utils.py
CHANGED
|
@@ -20,16 +20,18 @@ from collections.abc import Iterable
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
|
|
22
22
|
from mindspore.common.tensor import Tensor
|
|
23
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
23
24
|
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
|
24
25
|
from mindspore.common import dtype as mstype
|
|
25
26
|
from mindspore import log as logger
|
|
26
|
-
from mindspore
|
|
27
|
+
from mindspore import _checkparam as Validator
|
|
27
28
|
from mindspore.common.api import _cell_graph_executor
|
|
28
29
|
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
|
|
29
30
|
from mindspore.train.checkpoint_pb2 import Checkpoint
|
|
30
31
|
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
|
|
31
32
|
from mindspore.train.lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
|
|
32
33
|
from mindspore.parallel._parallel_serialization import _make_dir
|
|
34
|
+
from mindspore.ops.operations import debug_ops
|
|
33
35
|
|
|
34
36
|
|
|
35
37
|
def _convert_type(types):
|
|
@@ -135,27 +137,48 @@ def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
|
|
|
135
137
|
return tensor_list_run, tensor_list_compile
|
|
136
138
|
|
|
137
139
|
|
|
138
|
-
def _check_to_numpy(plugin, tensor):
|
|
140
|
+
def _check_to_numpy(plugin, tensor, prim=None):
|
|
139
141
|
"""Check the tensor and return a numpy.ndarray."""
|
|
140
142
|
np_value = tensor.asnumpy()
|
|
141
143
|
np_value = np_value.copy()
|
|
144
|
+
summary_name = plugin.capitalize() + "Summary" if prim else "SummaryRecord"
|
|
142
145
|
if plugin == 'scalar':
|
|
143
146
|
if np_value.size == 1:
|
|
144
147
|
return np_value
|
|
145
|
-
raise ValueError(
|
|
148
|
+
raise ValueError(
|
|
149
|
+
f'For "{summary_name}", the v rank must be less than or equal to 1, but got {len(np_value)}.')
|
|
146
150
|
if plugin == 'image':
|
|
147
151
|
if np_value.ndim == 4:
|
|
148
152
|
return np_value
|
|
149
|
-
raise ValueError('The tensor seems not to hold a valid image.')
|
|
153
|
+
raise ValueError(f'For "{summary_name}", The tensor seems not to hold a valid image.')
|
|
150
154
|
if plugin in ('tensor', 'histogram'):
|
|
151
155
|
if np_value.ndim > 0:
|
|
152
156
|
return np_value
|
|
153
|
-
raise ValueError('The
|
|
157
|
+
raise ValueError(f'For "{summary_name}", The value should not be empty.')
|
|
154
158
|
return np_value
|
|
155
159
|
|
|
156
160
|
|
|
161
|
+
def check_summary_param(summary_name, tag, tensor):
|
|
162
|
+
"""Checks the tag is valid for summary."""
|
|
163
|
+
plugin = summary_name.split('Summary')[0].lower()
|
|
164
|
+
try:
|
|
165
|
+
if not isinstance(tag, str) or not tag:
|
|
166
|
+
raise TypeError(f'For "{summary_name}", the name must be valid string, but got "{tag}".')
|
|
167
|
+
if not isinstance(tensor, (Tensor, Tensor_)):
|
|
168
|
+
raise TypeError(f'For "{summary_name}", the parameter "value" expect to be Tensor, '
|
|
169
|
+
f'but got {type(tensor).__name__}')
|
|
170
|
+
_check_to_numpy(plugin, tensor, prim=True)
|
|
171
|
+
except TypeError as err:
|
|
172
|
+
raise TypeError(err)
|
|
173
|
+
except ValueError as err:
|
|
174
|
+
raise ValueError(err)
|
|
175
|
+
finally:
|
|
176
|
+
debug_ops.SUMMARY_TENSOR_CACHE = []
|
|
177
|
+
|
|
178
|
+
|
|
157
179
|
def _check_lineage_value(plugin, value):
|
|
158
180
|
"""Check the lineage value."""
|
|
181
|
+
|
|
159
182
|
def raises(plugin, prototype):
|
|
160
183
|
raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
|
|
161
184
|
|
mindspore/train/amp.py
CHANGED
|
@@ -15,9 +15,9 @@
|
|
|
15
15
|
"""Auto mixed precision."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
+
import mindspore as ms
|
|
18
19
|
from mindspore import nn
|
|
19
|
-
from mindspore
|
|
20
|
-
from mindspore._checkparam import Rel
|
|
20
|
+
from mindspore import _checkparam as validator
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
22
|
from mindspore.nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
|
|
23
23
|
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
|
|
@@ -25,73 +25,231 @@ from mindspore.ops import functional as F
|
|
|
25
25
|
from mindspore.parallel._utils import _get_pipeline_stages
|
|
26
26
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
|
27
27
|
from mindspore import boost, context
|
|
28
|
+
from mindspore.ops import operations as P
|
|
29
|
+
from mindspore.ops import Primitive
|
|
30
|
+
from mindspore import log as logger
|
|
28
31
|
|
|
29
32
|
|
|
30
|
-
|
|
31
|
-
|
|
33
|
+
STREE = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
AMP_WHITE_LIST = [
|
|
32
37
|
nn.Conv1d,
|
|
33
38
|
nn.Conv2d,
|
|
34
39
|
nn.Conv3d,
|
|
35
40
|
nn.Conv1dTranspose,
|
|
36
41
|
nn.Conv2dTranspose,
|
|
37
|
-
nn.Conv3dTranspose
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
42
|
+
nn.Conv3dTranspose,
|
|
43
|
+
nn.Dense,
|
|
44
|
+
nn.LSTMCell,
|
|
45
|
+
nn.RNNCell,
|
|
46
|
+
nn.GRUCell,
|
|
47
|
+
P.Conv2D,
|
|
48
|
+
P.Conv3D,
|
|
49
|
+
P.Conv2DTranspose,
|
|
50
|
+
P.Conv3DTranspose,
|
|
51
|
+
P.Conv2DBackpropInput,
|
|
52
|
+
P.MatMul,
|
|
53
|
+
P.BatchMatMul,
|
|
54
|
+
P.PReLU,
|
|
55
|
+
P.ReLU,
|
|
56
|
+
P.Ger
|
|
57
|
+
]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
AMP_BLACK_LIST = [
|
|
41
61
|
nn.BatchNorm1d,
|
|
42
62
|
nn.BatchNorm2d,
|
|
43
63
|
nn.BatchNorm3d,
|
|
44
64
|
nn.LayerNorm
|
|
45
|
-
|
|
65
|
+
]
|
|
46
66
|
|
|
47
67
|
|
|
48
68
|
class _OutputTo16(nn.Cell):
|
|
49
69
|
"""Wrap cell for amp. Cast network output back to float16."""
|
|
50
|
-
|
|
51
|
-
def __init__(self, op):
|
|
70
|
+
def __init__(self, backbone):
|
|
52
71
|
super(_OutputTo16, self).__init__(auto_prefix=False)
|
|
53
|
-
self.
|
|
72
|
+
self._backbone = backbone
|
|
73
|
+
if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
|
|
74
|
+
self._jit_config_dict = backbone.jit_config_dict
|
|
54
75
|
|
|
55
76
|
def construct(self, x):
|
|
56
|
-
return F.cast(self.
|
|
77
|
+
return F.cast(self._backbone(x), mstype.float16)
|
|
57
78
|
|
|
58
79
|
|
|
59
80
|
class _OutputTo32(nn.Cell):
|
|
60
|
-
"Wrap loss for amp. Cast network output back to float32"
|
|
61
|
-
|
|
81
|
+
"""Wrap loss for amp. Cast network output back to float32."""
|
|
62
82
|
def __init__(self, backbone):
|
|
63
83
|
super(_OutputTo32, self).__init__(auto_prefix=False)
|
|
64
84
|
self._backbone = backbone
|
|
65
|
-
|
|
85
|
+
if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
|
|
86
|
+
self._jit_config_dict = backbone.jit_config_dict
|
|
66
87
|
|
|
67
88
|
def construct(self, *inputs):
|
|
68
89
|
out = self._backbone(*inputs)
|
|
69
90
|
return F.mixed_precision_cast(mstype.float32, out)
|
|
70
91
|
|
|
71
92
|
|
|
72
|
-
def
|
|
73
|
-
"""
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
93
|
+
def _allow_mix_precision(node, allowed_list) -> bool:
|
|
94
|
+
"""
|
|
95
|
+
Check whether current node need do mix precision. Follow conditions need to be satisfied:
|
|
96
|
+
1) Type of node is one of (Primitive, nn.Cell)
|
|
97
|
+
2) Node is not P.Cast()
|
|
98
|
+
3) to_float(mindspore.float16) is not set in Cell
|
|
99
|
+
"""
|
|
100
|
+
if node.get_instance() in allowed_list:
|
|
101
|
+
return True
|
|
102
|
+
if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
|
|
103
|
+
return False
|
|
104
|
+
if isinstance(node.get_instance(), P.Cast):
|
|
105
|
+
return False
|
|
106
|
+
if issubclass(node.get_instance_type(), nn.Cell):
|
|
107
|
+
# if cell is already in allowed_list, it means to_float(mindspore.float16) is set by amp.
|
|
108
|
+
# if cell is not in allowed_list, but has to_float(mindspore.float16),
|
|
109
|
+
# it means to_float(mindspore.float16) is set by user.
|
|
110
|
+
if node.get_instance().to_float_fp16:
|
|
111
|
+
return False
|
|
112
|
+
allowed_list.append(node.get_instance())
|
|
113
|
+
return True
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _insert_cast_operator_process(node, stree):
|
|
117
|
+
"""insert cast for operators in white_list."""
|
|
118
|
+
new_cast_node = None
|
|
119
|
+
# insert cast float16 before the primitive operators
|
|
120
|
+
if issubclass(node.get_instance_type(), Primitive):
|
|
121
|
+
for idx in range(len(node.get_inputs())):
|
|
122
|
+
position = stree.before(node)
|
|
123
|
+
new_node = P.Cast()
|
|
124
|
+
arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
|
|
125
|
+
"mindspore.float16"])
|
|
126
|
+
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
127
|
+
targets=['x_cast_{}'.format(node.get_name())],
|
|
128
|
+
args=arg,
|
|
129
|
+
name='incast_{}{}'.format(node.get_name(), idx))
|
|
130
|
+
stree.insert(position, new_cast_node)
|
|
131
|
+
node.set_arg_by_node(idx, new_cast_node)
|
|
132
|
+
# insert cast float16 before the Cell operators
|
|
133
|
+
elif issubclass(node.get_instance_type(), nn.Cell):
|
|
134
|
+
node.get_instance().to_float(mstype.float16)
|
|
135
|
+
# ignore if subclass is not one of (Primitive, nn.Cell)
|
|
136
|
+
else:
|
|
137
|
+
return
|
|
138
|
+
|
|
139
|
+
# insert cast float32 after the operators
|
|
140
|
+
position = stree.after(node)
|
|
141
|
+
new_node = P.Cast()
|
|
142
|
+
arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
|
143
|
+
"mindspore.float32"])
|
|
144
|
+
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
145
|
+
targets=['x_cast_{}'.format(node.get_name())],
|
|
146
|
+
args=arg,
|
|
147
|
+
name='outcast_{}'.format(node.get_name()))
|
|
148
|
+
# insert node & unique names
|
|
149
|
+
stree.insert(position, new_cast_node)
|
|
150
|
+
# update argument names
|
|
151
|
+
for user in node.get_users():
|
|
152
|
+
if user.get_name() == new_cast_node.get_name():
|
|
81
153
|
continue
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
154
|
+
for idx, arg in enumerate(user.get_args()):
|
|
155
|
+
if arg == node.get_targets()[0]:
|
|
156
|
+
user.set_arg_by_node(idx, new_cast_node)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _insert_cast_operator_white_list(stree, white_list):
|
|
160
|
+
"""insert cast for operators in white_list."""
|
|
161
|
+
allowed_list = []
|
|
162
|
+
for node in stree.nodes():
|
|
163
|
+
if node.get_targets() is None:
|
|
164
|
+
continue
|
|
165
|
+
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
166
|
+
for n in node.get_handler().node_list:
|
|
167
|
+
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
168
|
+
_insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
|
|
169
|
+
white_list)
|
|
170
|
+
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
171
|
+
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
172
|
+
_insert_cast_operator_white_list(substree, white_list)
|
|
173
|
+
elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list):
|
|
174
|
+
_insert_cast_operator_process(node, stree)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _need_removed_cast_pair(node):
|
|
178
|
+
"""check whether the cast pairs should be removed."""
|
|
179
|
+
cast_dtypes = ms.rewrite.ScopedValue.create_name_values(["mindspore.float16", "mindspore.float32"])
|
|
180
|
+
cast_dtype_f16 = cast_dtypes[0]
|
|
181
|
+
cast_dtype_f32 = cast_dtypes[1]
|
|
182
|
+
# current node should be P.Cast()(x, mindspore.float32)
|
|
183
|
+
if node.get_instance_type() != P.Cast:
|
|
184
|
+
return False
|
|
185
|
+
node_cast_type = node.get_args()[1]
|
|
186
|
+
if node_cast_type != cast_dtype_f32:
|
|
187
|
+
return False
|
|
188
|
+
# all user nodes should be P.Cast()(x, mindspore.float16) or Cell with to_float(mindspore.float16)
|
|
189
|
+
if not node.get_users():
|
|
190
|
+
return False
|
|
191
|
+
for user in node.get_users():
|
|
192
|
+
if isinstance(user.get_instance(), nn.Cell):
|
|
193
|
+
if not user.get_instance().to_float_fp16:
|
|
194
|
+
return False
|
|
195
|
+
elif user.get_instance_type() == P.Cast:
|
|
196
|
+
user_cast_type = user.get_args()[1]
|
|
197
|
+
if user_cast_type != cast_dtype_f16:
|
|
198
|
+
return False
|
|
85
199
|
else:
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
200
|
+
return False
|
|
201
|
+
return True
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _removed_cast_pair_process(stree, cast_f32_node):
|
|
205
|
+
"""remove the duplicated cast operators."""
|
|
206
|
+
for user_node in cast_f32_node.get_users():
|
|
207
|
+
# remove cast f16 nodes
|
|
208
|
+
if user_node.get_instance_type() == P.Cast:
|
|
209
|
+
cast_f16_node = user_node
|
|
210
|
+
# modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
|
|
211
|
+
for cast_f16_user in cast_f16_node.get_users():
|
|
212
|
+
for idx, arg in enumerate(cast_f16_user.get_args()):
|
|
213
|
+
if arg == cast_f16_node.get_targets()[0]:
|
|
214
|
+
cast_f16_user.set_arg(idx, cast_f32_node.get_args()[0])
|
|
215
|
+
stree.erase_node(cast_f16_node)
|
|
216
|
+
# update args of cell f16 nodes
|
|
217
|
+
elif isinstance(user_node.get_instance(), nn.Cell):
|
|
218
|
+
cell_f16_node = user_node
|
|
219
|
+
for idx, arg in enumerate(cell_f16_node.get_args()):
|
|
220
|
+
if arg == cast_f32_node.get_targets()[0]:
|
|
221
|
+
cell_f16_node.set_arg(idx, cast_f32_node.get_args()[0])
|
|
222
|
+
# remove the cast f32 node
|
|
223
|
+
stree.erase_node(cast_f32_node)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _remove_duplicated_cast(stree):
    """
    Recursively walk `stree` and fold redundant Cast pairs.

    - Nodes whose targets are None are skipped.
    - CellContainer nodes: every entry of the container handler's `node_list` whose type is Tree is recursed into.
    - Tree nodes: their sub-tree is recursed into.
    - Any other node accepted by `_need_removed_cast_pair` is folded by `_removed_cast_pair_process`.

    NOTE(review): `_removed_cast_pair_process` erases nodes from `stree` while `stree.nodes()` is still being
    iterated here; confirm the rewrite iterator tolerates this.

    Args:
        stree: The rewrite symbol tree (or sub-tree) to process in place.
    """
    for node in stree.nodes():
        if node.get_targets() is None:
            continue

        node_kind = node.get_node_type()
        if node_kind == ms.rewrite.NodeType.CellContainer:
            nested_trees = (entry for entry in node.get_handler().node_list
                            if entry.get_node_type() == ms.rewrite.NodeType.Tree)
            for entry in nested_trees:
                _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(entry)))
        elif node_kind == ms.rewrite.NodeType.Tree:
            _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(node))
        elif _need_removed_cast_pair(node):
            _removed_cast_pair_process(stree, node)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _auto_white_list(network, white_list):
    """
    Apply white-list mixed precision to `network` through MindSpore rewrite.

    A symbol tree is built from `network`. Cast operators are inserted for the entries of `white_list`, and
    redundant Cast pairs are then folded. The network regenerated from the tree is returned.

    Args:
        network (Cell): The network to rewrite.
        white_list (list): Primitive/Cell classes that should run in float16.

    Returns:
        Cell, the network produced by the rewritten symbol tree.
    """
    # The tree is published through the module-level STREE on purpose:
    # other helpers in this module may read the current tree from there.
    global STREE
    STREE = ms.rewrite.SymbolTree.create(network)
    _insert_cast_operator_white_list(STREE, white_list)
    _remove_duplicated_cast(STREE)
    rewritten_network = STREE.get_network()
    return rewritten_network
|
|
89
249
|
|
|
90
250
|
|
|
91
|
-
def _auto_black_list(network, black_list
|
|
251
|
+
def _auto_black_list(network, black_list):
|
|
92
252
|
"""process the black list of network."""
|
|
93
|
-
if black_list is None:
|
|
94
|
-
black_list = AMP_BLACK_LIST
|
|
95
253
|
network.to_float(mstype.float16)
|
|
96
254
|
cells = network.name_cells()
|
|
97
255
|
change = False
|
|
@@ -99,7 +257,7 @@ def _auto_black_list(network, black_list=None):
|
|
|
99
257
|
subcell = cells[name]
|
|
100
258
|
if subcell == network:
|
|
101
259
|
continue
|
|
102
|
-
if isinstance(subcell, black_list):
|
|
260
|
+
if isinstance(subcell, tuple(black_list)):
|
|
103
261
|
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
|
104
262
|
change = True
|
|
105
263
|
else:
|
|
@@ -117,7 +275,7 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|
|
117
275
|
amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
|
|
118
276
|
|
|
119
277
|
- "O0": Do not change.
|
|
120
|
-
- "O1":
|
|
278
|
+
- "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
|
|
121
279
|
- "O2": Cast network to float16, keep operators in black_list run in float32,
|
|
122
280
|
- "O3": Cast network to float16.
|
|
123
281
|
|
|
@@ -125,25 +283,24 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|
|
125
283
|
ValueError: If amp level is not supported.
|
|
126
284
|
|
|
127
285
|
Examples:
|
|
128
|
-
>>> from
|
|
286
|
+
>>> from mindspore import amp, nn
|
|
129
287
|
>>> network = LeNet5()
|
|
130
288
|
>>> amp_level = "O1"
|
|
131
289
|
>>> net = amp.auto_mixed_precision(network, amp_level)
|
|
132
290
|
"""
|
|
133
291
|
if not isinstance(network, nn.Cell):
|
|
134
292
|
raise TypeError("The network type should be Cell.")
|
|
293
|
+
|
|
135
294
|
if amp_level == "O0":
|
|
136
295
|
pass
|
|
137
296
|
elif amp_level == "O1":
|
|
138
|
-
_auto_white_list(network)
|
|
139
|
-
return network
|
|
297
|
+
return _auto_white_list(network, AMP_WHITE_LIST)
|
|
140
298
|
elif amp_level == "O2":
|
|
141
|
-
_auto_black_list(network)
|
|
299
|
+
_auto_black_list(network, AMP_BLACK_LIST)
|
|
142
300
|
elif amp_level == "O3":
|
|
143
301
|
network.to_float(mstype.float16)
|
|
144
302
|
else:
|
|
145
303
|
raise ValueError("The amp level {} is not supported".format(amp_level))
|
|
146
|
-
|
|
147
304
|
if amp_level in ("O2", "O3"):
|
|
148
305
|
network = _OutputTo32(network)
|
|
149
306
|
return network
|
|
@@ -157,7 +314,7 @@ def _do_keep_batchnorm_fp32(network):
|
|
|
157
314
|
subcell = cells[name]
|
|
158
315
|
if subcell == network:
|
|
159
316
|
continue
|
|
160
|
-
elif isinstance(subcell, AMP_BLACK_LIST):
|
|
317
|
+
elif isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
|
|
161
318
|
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
|
162
319
|
change = True
|
|
163
320
|
else:
|
|
@@ -208,8 +365,8 @@ def _check_level(level, boost_level):
|
|
|
208
365
|
if not isinstance(level, str):
|
|
209
366
|
raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
|
|
210
367
|
but got type {}.".format(type(level)))
|
|
211
|
-
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'],
|
|
212
|
-
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'],
|
|
368
|
+
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
|
|
369
|
+
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)
|
|
213
370
|
|
|
214
371
|
if level == "auto":
|
|
215
372
|
device_target = context.get_context('device_target')
|
|
@@ -231,13 +388,12 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|
|
231
388
|
"""Add loss network."""
|
|
232
389
|
|
|
233
390
|
class WithLossCell(nn.Cell):
|
|
234
|
-
"Wrap loss for amp. Cast network output back to float32"
|
|
235
|
-
|
|
391
|
+
"""Wrap loss for amp. Cast network output back to float32."""
|
|
236
392
|
def __init__(self, backbone, loss_fn):
|
|
237
393
|
super(WithLossCell, self).__init__(auto_prefix=False)
|
|
238
394
|
self._backbone = backbone
|
|
239
395
|
self._loss_fn = loss_fn
|
|
240
|
-
if backbone.jit_config_dict:
|
|
396
|
+
if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
|
|
241
397
|
self._jit_config_dict = backbone.jit_config_dict
|
|
242
398
|
|
|
243
399
|
def construct(self, data, label):
|
|
@@ -265,7 +421,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
265
421
|
level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
|
|
266
422
|
|
|
267
423
|
- "O0": Do not change.
|
|
268
|
-
- "O1":
|
|
424
|
+
- "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
|
|
425
|
+
The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
|
|
426
|
+
Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
|
|
269
427
|
- "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
|
|
270
428
|
using dynamic loss scale.
|
|
271
429
|
- "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
|
|
@@ -302,7 +460,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
302
460
|
(with property `drop_overflow_update=False` ).
|
|
303
461
|
|
|
304
462
|
Examples:
|
|
305
|
-
>>> from
|
|
463
|
+
>>> from mindspore import amp, nn
|
|
306
464
|
>>> network = LeNet5()
|
|
307
465
|
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
|
|
308
466
|
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
|
|
@@ -327,7 +485,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
327
485
|
elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
|
|
328
486
|
pass
|
|
329
487
|
else:
|
|
330
|
-
auto_mixed_precision(network, level)
|
|
488
|
+
network = auto_mixed_precision(network, level)
|
|
331
489
|
|
|
332
490
|
if loss_fn:
|
|
333
491
|
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
|
|
@@ -360,3 +518,116 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
360
518
|
else:
|
|
361
519
|
network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
|
|
362
520
|
return network
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def get_white_list():
    """
    Return a copy of the white list used internally by auto mixed precision.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Returns:
        list, A shallow copy of the internal white list. Modifying it does not change the internal list.
    """
    return AMP_WHITE_LIST.copy()
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def get_black_list():
    """
    Return a copy of the black list used internally by auto mixed precision.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Returns:
        list, A shallow copy of the internal black list. Modifying it does not change the internal list.
    """
    return AMP_BLACK_LIST.copy()
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def custom_mixed_precision(network, *, white_list=None, black_list=None):
    """
    Custom mixed precision by setting whitelist or blacklist.
    When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
    When the `black_list` is provided, primitives and cells that are not in `black_list` will perform the precision
    conversion.
    Only one of `white_list` and `black_list` should be provided.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Note:
        - `custom_mixed_precision` should not be used at the same time as `auto_mixed_precision` . When both
          `build_train_network` and `custom_mixed_precision` are used, `build_train_network` needs to be called with
          `level='O0'` before calling `custom_mixed_precision` .
        - Primitives are not supported in the blacklist yet.

    Args:
        network (Cell): Definition of the network.
        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Default: None, means
            white list is not used.
        black_list (list[Primitive, Cell], optional): Black list of custom mixed precision. Default: None, means
            black list is not used.

    Returns:
        network (Cell), A network supporting mixed precision.

    Raises:
        TypeError: The network type is not Cell.
        TypeError: The provided `white_list` or `black_list` is not a list, or one of its elements is not a
            subclass of Cell or Primitive.
        ValueError: Neither `white_list` nor `black_list` is provided.
        ValueError: Both `white_list` and `black_list` are provided.

    Examples:
        >>> from mindspore import amp, nn
        >>> net = MyNet()
        >>> custom_white_list = amp.get_white_list()
        >>> custom_white_list.append(nn.Tanhshrink)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
    """
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    if white_list is None and black_list is None:
        raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")

    if white_list is not None and black_list is not None:
        raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                         "at the same time, please provide one or the other.")

    if white_list is not None:
        _list_check(white_list, "white_list")
        # The white-list path rewrites the network through a symbol tree and returns the regenerated network.
        return _auto_white_list(network, white_list)

    _list_check(black_list, "black_list")
    # `_auto_black_list` modifies `network` in place (to_float / _cells replacement) and returns nothing useful.
    _auto_black_list(network, black_list)
    # Presumably wraps the network so its outputs are cast back to float32; see _OutputTo32.
    network = _OutputTo32(network)
    return network
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def _list_check(custom_list: list, list_name: str):
|
|
611
|
+
"""
|
|
612
|
+
check whether custom list is valid
|
|
613
|
+
|
|
614
|
+
Raises:
|
|
615
|
+
TypeError: The type of custom_list is not list.
|
|
616
|
+
TypeError: The element in custom_list is not a class.
|
|
617
|
+
TypeError: The subclass of element in custom_list is not one of ['Cell', 'Primitive'].
|
|
618
|
+
"""
|
|
619
|
+
if not isinstance(custom_list, list):
|
|
620
|
+
raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")
|
|
621
|
+
|
|
622
|
+
for elem in custom_list:
|
|
623
|
+
if not isinstance(elem, type):
|
|
624
|
+
raise TypeError(f"The element in {list_name} should be a class, but got {elem}")
|
|
625
|
+
|
|
626
|
+
if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
|
|
627
|
+
raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
|
|
628
|
+
f"but got {elem}")
|
|
629
|
+
|
|
630
|
+
if list_name == 'black_list':
|
|
631
|
+
for elem in AMP_BLACK_LIST:
|
|
632
|
+
if elem not in custom_list:
|
|
633
|
+
logger.warning(f"{elem} is removed from internal black list.")
|
|
@@ -33,7 +33,9 @@ from mindspore.train.callback._history import History
|
|
|
33
33
|
from mindspore.train.callback._lambda_callback import LambdaCallback
|
|
34
34
|
from mindspore.train.callback._early_stop import EarlyStopping
|
|
35
35
|
from mindspore.train.callback._reduce_lr_on_plateau import ReduceLROnPlateau
|
|
36
|
+
from mindspore.train.callback._on_request_exit import OnRequestExit
|
|
37
|
+
from mindspore.train.callback._backup_and_restore import BackupAndRestore
|
|
36
38
|
|
|
37
39
|
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
|
|
38
40
|
"SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
|
|
39
|
-
"History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping"]
|
|
41
|
+
"History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping", "OnRequestExit", "BackupAndRestore"]
|