mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py
CHANGED
|
@@ -17,8 +17,7 @@ from __future__ import absolute_import
|
|
|
17
17
|
|
|
18
18
|
import mindspore as ms
|
|
19
19
|
from mindspore import nn
|
|
20
|
-
from mindspore
|
|
21
|
-
from mindspore._checkparam import Rel
|
|
20
|
+
from mindspore import _checkparam as validator
|
|
22
21
|
from mindspore.common import dtype as mstype
|
|
23
22
|
from mindspore.nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
|
|
24
23
|
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
|
|
@@ -27,12 +26,14 @@ from mindspore.parallel._utils import _get_pipeline_stages
|
|
|
27
26
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
|
28
27
|
from mindspore import boost, context
|
|
29
28
|
from mindspore.ops import operations as P
|
|
29
|
+
from mindspore.ops import Primitive
|
|
30
|
+
from mindspore import log as logger
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
STREE = None
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
|
|
36
|
+
AMP_WHITE_LIST = [
|
|
36
37
|
nn.Conv1d,
|
|
37
38
|
nn.Conv2d,
|
|
38
39
|
nn.Conv3d,
|
|
@@ -42,11 +43,7 @@ AMP_WHITE_LIST_Cell = (
|
|
|
42
43
|
nn.Dense,
|
|
43
44
|
nn.LSTMCell,
|
|
44
45
|
nn.RNNCell,
|
|
45
|
-
nn.GRUCell
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
AMP_WHITE_LIST_OPS = (
|
|
46
|
+
nn.GRUCell,
|
|
50
47
|
P.Conv2D,
|
|
51
48
|
P.Conv3D,
|
|
52
49
|
P.Conv2DTranspose,
|
|
@@ -57,106 +54,173 @@ AMP_WHITE_LIST_OPS = (
|
|
|
57
54
|
P.PReLU,
|
|
58
55
|
P.ReLU,
|
|
59
56
|
P.Ger
|
|
60
|
-
|
|
57
|
+
]
|
|
61
58
|
|
|
62
59
|
|
|
63
|
-
AMP_BLACK_LIST =
|
|
60
|
+
AMP_BLACK_LIST = [
|
|
64
61
|
nn.BatchNorm1d,
|
|
65
62
|
nn.BatchNorm2d,
|
|
66
63
|
nn.BatchNorm3d,
|
|
67
64
|
nn.LayerNorm
|
|
68
|
-
|
|
65
|
+
]
|
|
69
66
|
|
|
70
67
|
|
|
71
68
|
class _OutputTo16(nn.Cell):
|
|
72
69
|
"""Wrap cell for amp. Cast network output back to float16."""
|
|
73
|
-
|
|
74
|
-
def __init__(self, op):
|
|
70
|
+
def __init__(self, backbone):
|
|
75
71
|
super(_OutputTo16, self).__init__(auto_prefix=False)
|
|
76
|
-
self.
|
|
72
|
+
self._backbone = backbone
|
|
73
|
+
if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
|
|
74
|
+
self._jit_config_dict = backbone.jit_config_dict
|
|
77
75
|
|
|
78
76
|
def construct(self, x):
|
|
79
|
-
return F.cast(self.
|
|
77
|
+
return F.cast(self._backbone(x), mstype.float16)
|
|
80
78
|
|
|
81
79
|
|
|
82
80
|
class _OutputTo32(nn.Cell):
|
|
83
|
-
"Wrap loss for amp. Cast network output back to float32"
|
|
84
|
-
|
|
81
|
+
"""Wrap loss for amp. Cast network output back to float32."""
|
|
85
82
|
def __init__(self, backbone):
|
|
86
83
|
super(_OutputTo32, self).__init__(auto_prefix=False)
|
|
87
84
|
self._backbone = backbone
|
|
88
|
-
|
|
85
|
+
if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
|
|
86
|
+
self._jit_config_dict = backbone.jit_config_dict
|
|
89
87
|
|
|
90
88
|
def construct(self, *inputs):
|
|
91
89
|
out = self._backbone(*inputs)
|
|
92
90
|
return F.mixed_precision_cast(mstype.float32, out)
|
|
93
91
|
|
|
94
92
|
|
|
95
|
-
def
|
|
93
|
+
def _allow_mix_precision(node, allowed_list) -> bool:
|
|
94
|
+
"""
|
|
95
|
+
Check whether current node need do mix precision. Follow conditions need to be satisfied:
|
|
96
|
+
1) Type of node is one of (Primitive, nn.Cell)
|
|
97
|
+
2) Node is not P.Cast()
|
|
98
|
+
3) to_float(mindspore.float16) is not set in Cell
|
|
99
|
+
"""
|
|
100
|
+
if node.get_instance() in allowed_list:
|
|
101
|
+
return True
|
|
102
|
+
if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
|
|
103
|
+
return False
|
|
104
|
+
if isinstance(node.get_instance(), P.Cast):
|
|
105
|
+
return False
|
|
106
|
+
if issubclass(node.get_instance_type(), nn.Cell):
|
|
107
|
+
# if cell is already in allowed_list, it means to_float(mindspore.float16) is set by amp.
|
|
108
|
+
# if cell is not in allowed_list, but has to_float(mindspore.float16),
|
|
109
|
+
# it means to_float(mindspore.float16) is set by user.
|
|
110
|
+
if node.get_instance().to_float_fp16:
|
|
111
|
+
return False
|
|
112
|
+
allowed_list.append(node.get_instance())
|
|
113
|
+
return True
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _insert_cast_operator_process(node, stree):
|
|
96
117
|
"""insert cast for operators in white_list."""
|
|
97
118
|
new_cast_node = None
|
|
119
|
+
# insert cast float16 before the primitive operators
|
|
120
|
+
if issubclass(node.get_instance_type(), Primitive):
|
|
121
|
+
for idx in range(len(node.get_inputs())):
|
|
122
|
+
position = stree.before(node)
|
|
123
|
+
new_node = P.Cast()
|
|
124
|
+
arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
|
|
125
|
+
"mindspore.float16"])
|
|
126
|
+
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
127
|
+
targets=['x_cast_{}'.format(node.get_name())],
|
|
128
|
+
args=arg,
|
|
129
|
+
name='incast_{}{}'.format(node.get_name(), idx))
|
|
130
|
+
stree.insert(position, new_cast_node)
|
|
131
|
+
node.set_arg_by_node(idx, new_cast_node)
|
|
132
|
+
# insert cast float16 before the Cell operators
|
|
133
|
+
elif issubclass(node.get_instance_type(), nn.Cell):
|
|
134
|
+
node.get_instance().to_float(mstype.float16)
|
|
135
|
+
# ignore if subclass is not one of (Primitive, nn.Cell)
|
|
136
|
+
else:
|
|
137
|
+
return
|
|
138
|
+
|
|
139
|
+
# insert cast float32 after the operators
|
|
140
|
+
position = stree.after(node)
|
|
141
|
+
new_node = P.Cast()
|
|
142
|
+
arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
|
143
|
+
"mindspore.float32"])
|
|
144
|
+
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
145
|
+
targets=['x_cast_{}'.format(node.get_name())],
|
|
146
|
+
args=arg,
|
|
147
|
+
name='outcast_{}'.format(node.get_name()))
|
|
148
|
+
# insert node & unique names
|
|
149
|
+
stree.insert(position, new_cast_node)
|
|
150
|
+
# update argument names
|
|
151
|
+
for user in node.get_users():
|
|
152
|
+
if user.get_name() == new_cast_node.get_name():
|
|
153
|
+
continue
|
|
154
|
+
for idx, arg in enumerate(user.get_args()):
|
|
155
|
+
if arg == node.get_targets()[0]:
|
|
156
|
+
user.set_arg_by_node(idx, new_cast_node)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _insert_cast_operator_white_list(stree, white_list):
|
|
160
|
+
"""insert cast for operators in white_list."""
|
|
161
|
+
allowed_list = []
|
|
98
162
|
for node in stree.nodes():
|
|
99
163
|
if node.get_targets() is None:
|
|
100
164
|
continue
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
position = stree.before(node)
|
|
108
|
-
new_node = P.Cast()
|
|
109
|
-
arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
|
|
110
|
-
"mindspore.float16"])
|
|
111
|
-
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
112
|
-
targets=['x_cast_{}'.format(node.get_name())],
|
|
113
|
-
args=arg,
|
|
114
|
-
name='incast_{}{}'.format(node.get_name(), idx))
|
|
115
|
-
stree.insert(position, new_cast_node)
|
|
116
|
-
node.set_arg_by_node(idx, new_cast_node)
|
|
117
|
-
# insert cast before the Cell operators in white_list
|
|
118
|
-
elif node.get_instance_type() in AMP_WHITE_LIST_Cell:
|
|
119
|
-
in_white_list = True
|
|
120
|
-
node.get_instance().to_float(mstype.float16)
|
|
121
|
-
|
|
122
|
-
# insert cast after the operators in white_list
|
|
123
|
-
if in_white_list:
|
|
124
|
-
position = stree.after(node)
|
|
125
|
-
new_node = P.Cast()
|
|
126
|
-
arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
|
127
|
-
"mindspore.float32"])
|
|
128
|
-
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
129
|
-
targets=['x_cast_{}'.format(node.get_name())],
|
|
130
|
-
args=arg,
|
|
131
|
-
name='outcast_{}'.format(node.get_name()))
|
|
132
|
-
for i in range(len(node.get_users())):
|
|
133
|
-
follow_node = node.get_users()[i]
|
|
134
|
-
stree.insert(position, new_cast_node)
|
|
135
|
-
idx = follow_node.get_args().index(node.get_targets()[0])
|
|
136
|
-
follow_node.set_arg_by_node(idx, new_cast_node)
|
|
137
|
-
else:
|
|
165
|
+
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
166
|
+
for n in node.get_handler().node_list:
|
|
167
|
+
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
168
|
+
_insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
|
|
169
|
+
white_list)
|
|
170
|
+
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
138
171
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
172
|
+
_insert_cast_operator_white_list(substree, white_list)
|
|
173
|
+
elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list):
|
|
174
|
+
_insert_cast_operator_process(node, stree)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _need_removed_cast_pair(node):
|
|
178
|
+
"""check whether the cast pairs should be removed."""
|
|
179
|
+
cast_dtypes = ms.rewrite.ScopedValue.create_name_values(["mindspore.float16", "mindspore.float32"])
|
|
180
|
+
cast_dtype_f16 = cast_dtypes[0]
|
|
181
|
+
cast_dtype_f32 = cast_dtypes[1]
|
|
182
|
+
# current node should be P.Cast()(x, mindspore.float32)
|
|
183
|
+
if node.get_instance_type() != P.Cast:
|
|
150
184
|
return False
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
185
|
+
node_cast_type = node.get_args()[1]
|
|
186
|
+
if node_cast_type != cast_dtype_f32:
|
|
187
|
+
return False
|
|
188
|
+
# all user nodes should be P.Cast()(x, mindspore.float16) or Cell with to_float(mindspore.float16)
|
|
189
|
+
if not node.get_users():
|
|
190
|
+
return False
|
|
191
|
+
for user in node.get_users():
|
|
192
|
+
if isinstance(user.get_instance(), nn.Cell):
|
|
193
|
+
if not user.get_instance().to_float_fp16:
|
|
194
|
+
return False
|
|
195
|
+
elif user.get_instance_type() == P.Cast:
|
|
196
|
+
user_cast_type = user.get_args()[1]
|
|
197
|
+
if user_cast_type != cast_dtype_f16:
|
|
198
|
+
return False
|
|
199
|
+
else:
|
|
200
|
+
return False
|
|
201
|
+
return True
|
|
158
202
|
|
|
159
|
-
|
|
203
|
+
|
|
204
|
+
def _removed_cast_pair_process(stree, cast_f32_node):
|
|
205
|
+
"""remove the duplicated cast operators."""
|
|
206
|
+
for user_node in cast_f32_node.get_users():
|
|
207
|
+
# remove cast f16 nodes
|
|
208
|
+
if user_node.get_instance_type() == P.Cast:
|
|
209
|
+
cast_f16_node = user_node
|
|
210
|
+
# modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
|
|
211
|
+
for cast_f16_user in cast_f16_node.get_users():
|
|
212
|
+
for idx, arg in enumerate(cast_f16_user.get_args()):
|
|
213
|
+
if arg == cast_f16_node.get_targets()[0]:
|
|
214
|
+
cast_f16_user.set_arg(idx, cast_f32_node.get_args()[0])
|
|
215
|
+
stree.erase_node(cast_f16_node)
|
|
216
|
+
# update args of cell f16 nodes
|
|
217
|
+
elif isinstance(user_node.get_instance(), nn.Cell):
|
|
218
|
+
cell_f16_node = user_node
|
|
219
|
+
for idx, arg in enumerate(cell_f16_node.get_args()):
|
|
220
|
+
if arg == cast_f32_node.get_targets()[0]:
|
|
221
|
+
cell_f16_node.set_arg(idx, cast_f32_node.get_args()[0])
|
|
222
|
+
# remove the cast f32 node
|
|
223
|
+
stree.erase_node(cast_f32_node)
|
|
160
224
|
|
|
161
225
|
|
|
162
226
|
def _remove_duplicated_cast(stree):
|
|
@@ -164,36 +228,28 @@ def _remove_duplicated_cast(stree):
|
|
|
164
228
|
for node in stree.nodes():
|
|
165
229
|
if node.get_targets() is None:
|
|
166
230
|
continue
|
|
167
|
-
if node.get_node_type()
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
follow_node = node.get_users()[i]
|
|
173
|
-
for n in follow_node.get_users():
|
|
174
|
-
idx = n.get_args().index(follow_node.get_targets()[0])
|
|
175
|
-
n.set_arg_by_node(idx, node.get_inputs()[0])
|
|
176
|
-
stree.erase_node(follow_node)
|
|
177
|
-
# remove the current cast node
|
|
178
|
-
stree.erase_node(node)
|
|
179
|
-
else:
|
|
231
|
+
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
232
|
+
for n in node.get_handler().node_list:
|
|
233
|
+
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
234
|
+
_remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
|
|
235
|
+
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
180
236
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
181
237
|
_remove_duplicated_cast(substree)
|
|
238
|
+
elif _need_removed_cast_pair(node):
|
|
239
|
+
_removed_cast_pair_process(stree, node)
|
|
182
240
|
|
|
183
241
|
|
|
184
|
-
def _auto_white_list(network):
|
|
242
|
+
def _auto_white_list(network, white_list):
|
|
185
243
|
"""process the white list of network."""
|
|
186
244
|
global STREE
|
|
187
245
|
STREE = ms.rewrite.SymbolTree.create(network)
|
|
188
|
-
|
|
246
|
+
_insert_cast_operator_white_list(STREE, white_list)
|
|
189
247
|
_remove_duplicated_cast(STREE)
|
|
190
248
|
return STREE.get_network()
|
|
191
249
|
|
|
192
250
|
|
|
193
|
-
def _auto_black_list(network, black_list
|
|
251
|
+
def _auto_black_list(network, black_list):
|
|
194
252
|
"""process the black list of network."""
|
|
195
|
-
if black_list is None:
|
|
196
|
-
black_list = AMP_BLACK_LIST
|
|
197
253
|
network.to_float(mstype.float16)
|
|
198
254
|
cells = network.name_cells()
|
|
199
255
|
change = False
|
|
@@ -201,7 +257,7 @@ def _auto_black_list(network, black_list=None):
|
|
|
201
257
|
subcell = cells[name]
|
|
202
258
|
if subcell == network:
|
|
203
259
|
continue
|
|
204
|
-
if isinstance(subcell, black_list):
|
|
260
|
+
if isinstance(subcell, tuple(black_list)):
|
|
205
261
|
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
|
206
262
|
change = True
|
|
207
263
|
else:
|
|
@@ -234,12 +290,13 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|
|
234
290
|
"""
|
|
235
291
|
if not isinstance(network, nn.Cell):
|
|
236
292
|
raise TypeError("The network type should be Cell.")
|
|
293
|
+
|
|
237
294
|
if amp_level == "O0":
|
|
238
295
|
pass
|
|
239
296
|
elif amp_level == "O1":
|
|
240
|
-
return _auto_white_list(network)
|
|
297
|
+
return _auto_white_list(network, AMP_WHITE_LIST)
|
|
241
298
|
elif amp_level == "O2":
|
|
242
|
-
_auto_black_list(network)
|
|
299
|
+
_auto_black_list(network, AMP_BLACK_LIST)
|
|
243
300
|
elif amp_level == "O3":
|
|
244
301
|
network.to_float(mstype.float16)
|
|
245
302
|
else:
|
|
@@ -257,7 +314,7 @@ def _do_keep_batchnorm_fp32(network):
|
|
|
257
314
|
subcell = cells[name]
|
|
258
315
|
if subcell == network:
|
|
259
316
|
continue
|
|
260
|
-
elif isinstance(subcell, AMP_BLACK_LIST):
|
|
317
|
+
elif isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
|
|
261
318
|
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
|
262
319
|
change = True
|
|
263
320
|
else:
|
|
@@ -308,8 +365,8 @@ def _check_level(level, boost_level):
|
|
|
308
365
|
if not isinstance(level, str):
|
|
309
366
|
raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
|
|
310
367
|
but got type {}.".format(type(level)))
|
|
311
|
-
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'],
|
|
312
|
-
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'],
|
|
368
|
+
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
|
|
369
|
+
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)
|
|
313
370
|
|
|
314
371
|
if level == "auto":
|
|
315
372
|
device_target = context.get_context('device_target')
|
|
@@ -331,13 +388,12 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|
|
331
388
|
"""Add loss network."""
|
|
332
389
|
|
|
333
390
|
class WithLossCell(nn.Cell):
|
|
334
|
-
"Wrap loss for amp. Cast network output back to float32"
|
|
335
|
-
|
|
391
|
+
"""Wrap loss for amp. Cast network output back to float32."""
|
|
336
392
|
def __init__(self, backbone, loss_fn):
|
|
337
393
|
super(WithLossCell, self).__init__(auto_prefix=False)
|
|
338
394
|
self._backbone = backbone
|
|
339
395
|
self._loss_fn = loss_fn
|
|
340
|
-
if backbone.jit_config_dict:
|
|
396
|
+
if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
|
|
341
397
|
self._jit_config_dict = backbone.jit_config_dict
|
|
342
398
|
|
|
343
399
|
def construct(self, data, label):
|
|
@@ -366,6 +422,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
366
422
|
|
|
367
423
|
- "O0": Do not change.
|
|
368
424
|
- "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
|
|
425
|
+
The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
|
|
426
|
+
Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
|
|
369
427
|
- "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
|
|
370
428
|
using dynamic loss scale.
|
|
371
429
|
- "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
|
|
@@ -460,3 +518,116 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
460
518
|
else:
|
|
461
519
|
network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
|
|
462
520
|
return network
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def get_white_list():
|
|
524
|
+
"""
|
|
525
|
+
Provide a copy of internal white list used by auto mixed precision.
|
|
526
|
+
|
|
527
|
+
.. warning::
|
|
528
|
+
This is an experimental API that is subject to change or deletion.
|
|
529
|
+
|
|
530
|
+
Returns:
|
|
531
|
+
list, A copy of internal white list.
|
|
532
|
+
"""
|
|
533
|
+
white_list = AMP_WHITE_LIST.copy()
|
|
534
|
+
return white_list
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def get_black_list():
|
|
538
|
+
"""
|
|
539
|
+
Provide a copy of internal black list used by auto mixed precision.
|
|
540
|
+
|
|
541
|
+
.. warning::
|
|
542
|
+
This is an experimental API that is subject to change or deletion.
|
|
543
|
+
|
|
544
|
+
Returns:
|
|
545
|
+
list, A copy of internal black list.
|
|
546
|
+
"""
|
|
547
|
+
black_list = AMP_BLACK_LIST.copy()
|
|
548
|
+
return black_list
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def custom_mixed_precision(network, *, white_list=None, black_list=None):
|
|
552
|
+
"""
|
|
553
|
+
Custom mixed precision by setting whitelist or blacklist.
|
|
554
|
+
When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
|
|
555
|
+
When the `black_list` is provided, primitives and cells that are not in `black_list` will perform the pereision
|
|
556
|
+
conversion.
|
|
557
|
+
Only one of `white_list` and `black_list` should be provided.
|
|
558
|
+
|
|
559
|
+
.. warning::
|
|
560
|
+
This is an experimental API that is subject to change or deletion.
|
|
561
|
+
|
|
562
|
+
Note:
|
|
563
|
+
- `custom_mixed_precision` should not be used at the same time as `auto_mixed_precision` . When both
|
|
564
|
+
`build_train_network` and `custom_mixed_precision` are used, `build_train_network` need to be called with
|
|
565
|
+
`level='O0'` before call `custom_mixed_precision` .
|
|
566
|
+
- Primitives for blacklist is not support yet.
|
|
567
|
+
|
|
568
|
+
Args:
|
|
569
|
+
network (Cell): Definition of the network.
|
|
570
|
+
white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: None, means
|
|
571
|
+
white list is not used.
|
|
572
|
+
black_list (list[Primitive, Cell], optional): Black list of custom mixed precision. Defaults: None, means
|
|
573
|
+
black list is not used.
|
|
574
|
+
|
|
575
|
+
Returns:
|
|
576
|
+
network (Cell), A network supporting mixed precision.
|
|
577
|
+
|
|
578
|
+
Raises:
|
|
579
|
+
TypeError: The network type is not Cell.
|
|
580
|
+
ValueError: Neither `white_list` nor `black_list` is provided.
|
|
581
|
+
ValueError: Both `white_list` and `black_list` are provided.
|
|
582
|
+
|
|
583
|
+
Examples:
|
|
584
|
+
>>> from mindspore import amp
|
|
585
|
+
>>> net = MyNet()
|
|
586
|
+
>>> custom_white_list = amp.get_white_list()
|
|
587
|
+
>>> custom_white_list.append(nn.Tanhshrink)
|
|
588
|
+
>>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
|
|
589
|
+
"""
|
|
590
|
+
if not isinstance(network, nn.Cell):
|
|
591
|
+
raise TypeError("The network type should be Cell.")
|
|
592
|
+
|
|
593
|
+
if white_list is None and black_list is None:
|
|
594
|
+
raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")
|
|
595
|
+
|
|
596
|
+
if white_list is not None and black_list is not None:
|
|
597
|
+
raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
|
|
598
|
+
"at the same time, please provide one or the other.")
|
|
599
|
+
|
|
600
|
+
if white_list is not None:
|
|
601
|
+
_list_check(white_list, "white_list")
|
|
602
|
+
return _auto_white_list(network, white_list)
|
|
603
|
+
|
|
604
|
+
_list_check(black_list, "black_list")
|
|
605
|
+
_auto_black_list(network, black_list)
|
|
606
|
+
network = _OutputTo32(network)
|
|
607
|
+
return network
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def _list_check(custom_list: list, list_name: str):
|
|
611
|
+
"""
|
|
612
|
+
check whether custom list is valid
|
|
613
|
+
|
|
614
|
+
Raises:
|
|
615
|
+
TypeError: The type of custom_list is not list.
|
|
616
|
+
TypeError: The element in custom_list is not a class.
|
|
617
|
+
TypeError: The subclass of element in custom_list is not one of ['Cell', 'Primitive'].
|
|
618
|
+
"""
|
|
619
|
+
if not isinstance(custom_list, list):
|
|
620
|
+
raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")
|
|
621
|
+
|
|
622
|
+
for elem in custom_list:
|
|
623
|
+
if not isinstance(elem, type):
|
|
624
|
+
raise TypeError(f"The element in {list_name} should be a class, but got {elem}")
|
|
625
|
+
|
|
626
|
+
if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
|
|
627
|
+
raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
|
|
628
|
+
f"but got {elem}")
|
|
629
|
+
|
|
630
|
+
if list_name == 'black_list':
|
|
631
|
+
for elem in AMP_BLACK_LIST:
|
|
632
|
+
if elem not in custom_list:
|
|
633
|
+
logger.warning(f"{elem} is removed from internal black list.")
|
|
@@ -22,7 +22,7 @@ from mindspore import log as logger
|
|
|
22
22
|
from mindspore.train.serialization import load_checkpoint, save_checkpoint
|
|
23
23
|
from mindspore.train.callback._callback import Callback
|
|
24
24
|
from mindspore.train._utils import _make_directory
|
|
25
|
-
from mindspore
|
|
25
|
+
from mindspore import _checkparam as Validator
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class BackupAndRestore(Callback):
|
|
@@ -34,7 +34,7 @@ class BackupAndRestore(Callback):
|
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
36
|
backup_dir (str): Path to store and load the checkpoint file.
|
|
37
|
-
save_freq(Union['epoch', int]): When set to
|
|
37
|
+
save_freq(Union['epoch', int]): When set to 'epoch' the callback saves the checkpoint at the end of
|
|
38
38
|
each epoch. When set to an integer, the callback saves the checkpoint
|
|
39
39
|
every `save_freq` epoch. Default: 'epoch'.
|
|
40
40
|
delete_checkpoint(bool): If `delete_checkpoint=True`, the checkpoint will be deleted after
|
|
@@ -49,8 +49,8 @@ class BackupAndRestore(Callback):
|
|
|
49
49
|
.. note::
|
|
50
50
|
Before running the following example, you need to customize the network LeNet5 and
|
|
51
51
|
dataset preparation function create_dataset. Refer to
|
|
52
|
-
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0
|
|
53
|
-
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0
|
|
52
|
+
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0/beginner/model.html>`_
|
|
53
|
+
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0/beginner/dataset.html>`_ .
|
|
54
54
|
|
|
55
55
|
>>> from mindspore import nn
|
|
56
56
|
>>> from mindspore.train import Model, BackupAndRestore
|
|
@@ -69,7 +69,7 @@ class BackupAndRestore(Callback):
|
|
|
69
69
|
ckpt_dir = _make_directory(backup_dir)
|
|
70
70
|
self.backup_file = os.path.join(ckpt_dir, 'backup.ckpt')
|
|
71
71
|
if save_freq != "epoch":
|
|
72
|
-
self.save_freq = Validator.check_positive_int(
|
|
72
|
+
self.save_freq = Validator.check_positive_int(save_freq)
|
|
73
73
|
else:
|
|
74
74
|
self.save_freq = 1
|
|
75
75
|
self.delete_checkpoint = Validator.check_bool(delete_checkpoint)
|
|
@@ -93,7 +93,7 @@ class Callback:
|
|
|
93
93
|
recording current attributes. Users can add custimized attributes to the information.
|
|
94
94
|
Training process can also be stopped by calling `request_stop` method. For details
|
|
95
95
|
of custom Callback, please check
|
|
96
|
-
`Callback <https://www.mindspore.cn/tutorials/experts/en/r2.0
|
|
96
|
+
`Callback <https://www.mindspore.cn/tutorials/experts/en/r2.0/debug/custom_debug.html>`_.
|
|
97
97
|
|
|
98
98
|
Examples:
|
|
99
99
|
>>> import numpy as np
|
|
@@ -437,7 +437,7 @@ class RunContext:
|
|
|
437
437
|
`RunContext.original_args()` and add extra attributes to the information, but also can stop the
|
|
438
438
|
training process by calling `request_stop` method. For details of custom Callback,
|
|
439
439
|
please check
|
|
440
|
-
`Callback <:https//www.mindspore.cn/tutorials/experts/en/r2.0
|
|
440
|
+
`Callback <:https//www.mindspore.cn/tutorials/experts/en/r2.0/debug/custom_debug.html>`_.
|
|
441
441
|
|
|
442
442
|
`RunContext.original_args()` holds the model context information as a dictionary variable, and
|
|
443
443
|
different attributes of the dictionary are stored in training or eval process. Details are as follows:
|
|
@@ -23,7 +23,7 @@ import threading
|
|
|
23
23
|
import mindspore.context as context
|
|
24
24
|
from mindspore import log as logger
|
|
25
25
|
from mindspore import nn
|
|
26
|
-
from mindspore
|
|
26
|
+
from mindspore import _checkparam as Validator
|
|
27
27
|
from mindspore.train._utils import _make_directory
|
|
28
28
|
from mindspore.train.serialization import save_checkpoint, _save_graph
|
|
29
29
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
|
@@ -107,8 +107,8 @@ class CheckpointConfig:
|
|
|
107
107
|
.. note::
|
|
108
108
|
Before running the following example, you need to customize the network LeNet5 and
|
|
109
109
|
dataset preparation function create_dataset. Refer to
|
|
110
|
-
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0
|
|
111
|
-
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0
|
|
110
|
+
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0/beginner/model.html>`_
|
|
111
|
+
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0/beginner/dataset.html>`_ .
|
|
112
112
|
|
|
113
113
|
>>> from mindspore import nn
|
|
114
114
|
>>> from mindspore.common.initializer import Normal
|
|
@@ -21,7 +21,7 @@ import numpy as np
|
|
|
21
21
|
|
|
22
22
|
from mindspore import ops, nn
|
|
23
23
|
from mindspore.common.tensor import Tensor
|
|
24
|
-
from mindspore
|
|
24
|
+
from mindspore import _checkparam as Validator
|
|
25
25
|
from mindspore.train.serialization import load_param_into_net
|
|
26
26
|
from mindspore import log as logger
|
|
27
27
|
from mindspore.ops import ReduceOp
|
|
@@ -85,8 +85,8 @@ class EarlyStopping(Callback):
|
|
|
85
85
|
.. note::
|
|
86
86
|
Before running the following example, you need to customize the network LeNet5 and
|
|
87
87
|
dataset preparation function create_dataset. Refer to
|
|
88
|
-
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0
|
|
89
|
-
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0
|
|
88
|
+
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0/beginner/model.html>`_
|
|
89
|
+
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0/beginner/dataset.html>`_ .
|
|
90
90
|
|
|
91
91
|
>>> from mindspore import nn
|
|
92
92
|
>>> from mindspore.train import Model, EarlyStopping
|
|
@@ -26,8 +26,8 @@ class LambdaCallback(Callback):
|
|
|
26
26
|
at the appropriate time (during `mindspore.train.Model.{train | eval | fit}`). Note that
|
|
27
27
|
each stage of callbacks expects one positional arguments: `run_context`.
|
|
28
28
|
|
|
29
|
-
|
|
30
|
-
This is an experimental
|
|
29
|
+
.. warning::
|
|
30
|
+
This is an experimental API that is subject to change or deletion.
|
|
31
31
|
|
|
32
32
|
Args:
|
|
33
33
|
on_train_epoch_begin (Function): called at each train epoch begin.
|