mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/amp.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,17 +16,20 @@
|
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
|
-
|
|
20
|
-
from .
|
|
19
|
+
from mindspore.common import mutable
|
|
20
|
+
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
21
|
+
from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
|
|
22
|
+
from mindspore import _checkparam as validator
|
|
21
23
|
from .common import dtype as mstype
|
|
22
24
|
from . import context
|
|
23
25
|
from . import ops
|
|
24
26
|
from .ops import constexpr
|
|
25
|
-
from .common.api import jit_class
|
|
27
|
+
from .common.api import jit_class, jit
|
|
26
28
|
from .common.parameter import Parameter
|
|
27
29
|
from .common.tensor import Tensor
|
|
28
30
|
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
|
|
29
|
-
from .train.amp import build_train_network, auto_mixed_precision
|
|
31
|
+
from .train.amp import build_train_network, auto_mixed_precision, custom_mixed_precision,\
|
|
32
|
+
get_white_list, get_black_list
|
|
30
33
|
|
|
31
34
|
|
|
32
35
|
_hypermap = ops.HyperMap()
|
|
@@ -51,46 +54,29 @@ def _grad_scale(scale, grad):
|
|
|
51
54
|
return grad * scale.astype(grad.dtype)
|
|
52
55
|
|
|
53
56
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
status = ops.isfinite(inputs)
|
|
58
|
-
return status.all()
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def init_status():
|
|
62
|
-
r"""
|
|
63
|
-
Returns a Tensor indicating initialized status for overflow detection.
|
|
64
|
-
|
|
65
|
-
Note:
|
|
66
|
-
Only Ascend need status to capture overflow status, you can also call
|
|
67
|
-
this function on GPU or CPU, but the return value is useless.
|
|
57
|
+
@jit
|
|
58
|
+
def _grad_scale_map(scale_value, inputs):
|
|
59
|
+
return _hypermap(_partial(_grad_scale, scale_value), inputs)
|
|
68
60
|
|
|
69
|
-
Returns:
|
|
70
|
-
Tensor, has the shape of `(8,)`.
|
|
71
61
|
|
|
72
|
-
|
|
73
|
-
|
|
62
|
+
@jit
|
|
63
|
+
def _grad_unscale_map(scale_value, inputs):
|
|
64
|
+
return _hypermap(_partial(_grad_unscale, scale_value), inputs)
|
|
74
65
|
|
|
75
|
-
Examples:
|
|
76
|
-
>>> status = amp.init_status()
|
|
77
|
-
"""
|
|
78
|
-
if _ascend_target():
|
|
79
|
-
status = ops.NPUAllocFloatStatus()()
|
|
80
|
-
clear_status = ops.NPUClearFloatStatus()(status)
|
|
81
|
-
status = ops.depend(status, clear_status)
|
|
82
|
-
else:
|
|
83
|
-
status = Tensor([0, 0, 0, 0, 0, 0, 0, 0], mstype.float32)
|
|
84
66
|
|
|
85
|
-
|
|
67
|
+
def _overflow(inputs):
|
|
68
|
+
if _gpu_target():
|
|
69
|
+
return ops.FloatStatus()(inputs)
|
|
70
|
+
status = ops.isfinite(inputs)
|
|
71
|
+
return 1 - status.all()
|
|
86
72
|
|
|
87
73
|
|
|
88
|
-
def all_finite(inputs
|
|
74
|
+
def all_finite(inputs):
|
|
89
75
|
r"""
|
|
90
76
|
Returns a scalar Tensor indicating whether the inputs are finite.
|
|
91
77
|
|
|
92
|
-
|
|
93
|
-
This is an experimental
|
|
78
|
+
.. warning::
|
|
79
|
+
This is an experimental API that is subject to change or deletion.
|
|
94
80
|
|
|
95
81
|
The interface must be used in whole network training scenario to detect
|
|
96
82
|
whether grads are finite, and the results may be different on different
|
|
@@ -98,8 +84,6 @@ def all_finite(inputs, status=None):
|
|
|
98
84
|
|
|
99
85
|
Args:
|
|
100
86
|
inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
|
|
101
|
-
status (Tensor): the status Tensor for overflow detection, only required on
|
|
102
|
-
Ascend. Default: None.
|
|
103
87
|
|
|
104
88
|
Returns:
|
|
105
89
|
Tensor, a scalar Tensor and the dtype is bool.
|
|
@@ -112,16 +96,18 @@ def all_finite(inputs, status=None):
|
|
|
112
96
|
>>> output = amp.all_finite(x)
|
|
113
97
|
"""
|
|
114
98
|
if _ascend_target():
|
|
115
|
-
|
|
116
|
-
raise ValueError("The status must be initialized on Ascend, but get 'None'.")
|
|
99
|
+
status = Tensor([0] * 8, mstype.int32)
|
|
117
100
|
status = ops.depend(status, inputs)
|
|
118
|
-
get_status =
|
|
101
|
+
get_status = _get_cache_prim(NPUGetFloatStatusV2)()(status)
|
|
119
102
|
status = ops.depend(status, get_status)
|
|
120
|
-
|
|
121
|
-
|
|
103
|
+
clear_status = _get_cache_prim(NPUClearFloatStatusV2)()(status)
|
|
104
|
+
get_status = ops.depend(get_status, clear_status)
|
|
105
|
+
status_finite = get_status.equal(Tensor(0, mstype.int32)).all()
|
|
122
106
|
return status_finite
|
|
123
|
-
outputs = _hypermap(_partial(
|
|
124
|
-
|
|
107
|
+
outputs = _hypermap(_partial(_overflow), inputs)
|
|
108
|
+
flag_sum = ops.addn(outputs).reshape(())
|
|
109
|
+
_all_finite = ops.less(flag_sum, 1)
|
|
110
|
+
return _all_finite
|
|
125
111
|
|
|
126
112
|
|
|
127
113
|
@jit_class
|
|
@@ -133,8 +119,11 @@ class LossScaler(ABC):
|
|
|
133
119
|
to scale and unscale the loss value and gradients to avoid overflow, `adjust` is used to update the
|
|
134
120
|
loss scale value.
|
|
135
121
|
|
|
136
|
-
|
|
137
|
-
|
|
122
|
+
For more information, refer to the `tutorials <https://mindspore.cn/tutorials/en/r2.0/advanced/
|
|
123
|
+
mixed_precision.html#loss-scaling>`_.
|
|
124
|
+
|
|
125
|
+
.. warning::
|
|
126
|
+
This is an experimental API that is subject to change or deletion.
|
|
138
127
|
"""
|
|
139
128
|
@abstractmethod
|
|
140
129
|
def scale(self, inputs):
|
|
@@ -173,8 +162,8 @@ class StaticLossScaler(LossScaler):
|
|
|
173
162
|
|
|
174
163
|
Scales and unscales loss or gradients by a fixed constant.
|
|
175
164
|
|
|
176
|
-
|
|
177
|
-
This is an experimental
|
|
165
|
+
.. warning::
|
|
166
|
+
This is an experimental API that is subject to change or deletion.
|
|
178
167
|
|
|
179
168
|
Args:
|
|
180
169
|
scale_value (Union(float, int)): The initial loss scale value.
|
|
@@ -211,7 +200,8 @@ class StaticLossScaler(LossScaler):
|
|
|
211
200
|
Returns:
|
|
212
201
|
Union(Tensor, tuple(Tensor)), the scaled value.
|
|
213
202
|
"""
|
|
214
|
-
|
|
203
|
+
inputs = mutable(inputs)
|
|
204
|
+
return _grad_scale_map(self.scale_value, inputs)
|
|
215
205
|
|
|
216
206
|
def unscale(self, inputs):
|
|
217
207
|
"""
|
|
@@ -223,7 +213,8 @@ class StaticLossScaler(LossScaler):
|
|
|
223
213
|
Returns:
|
|
224
214
|
Union(Tensor, tuple(Tensor)), the unscaled value.
|
|
225
215
|
"""
|
|
226
|
-
|
|
216
|
+
inputs = mutable(inputs)
|
|
217
|
+
return _grad_unscale_map(self.scale_value, inputs)
|
|
227
218
|
|
|
228
219
|
def adjust(self, grads_finite):
|
|
229
220
|
"""
|
|
@@ -244,8 +235,8 @@ class DynamicLossScaler(LossScaler):
|
|
|
244
235
|
`scale_window` steps by `factor` if the grads remain finite, otherwise it reduces
|
|
245
236
|
the loss scale by `1 / factor` and resets the counter.
|
|
246
237
|
|
|
247
|
-
|
|
248
|
-
This is an experimental
|
|
238
|
+
.. warning::
|
|
239
|
+
This is an experimental API that is subject to change or deletion.
|
|
249
240
|
|
|
250
241
|
Args:
|
|
251
242
|
scale_value (Union(float, int)): The initial loss scale value.
|
|
@@ -286,7 +277,8 @@ class DynamicLossScaler(LossScaler):
|
|
|
286
277
|
Returns:
|
|
287
278
|
Union(Tensor, tuple(Tensor)), the scaled value.
|
|
288
279
|
"""
|
|
289
|
-
|
|
280
|
+
inputs = mutable(inputs)
|
|
281
|
+
return _grad_scale_map(self.scale_value, inputs)
|
|
290
282
|
|
|
291
283
|
def unscale(self, inputs):
|
|
292
284
|
"""
|
|
@@ -298,8 +290,10 @@ class DynamicLossScaler(LossScaler):
|
|
|
298
290
|
Returns:
|
|
299
291
|
Union(Tensor, tuple(Tensor)), the unscaled value.
|
|
300
292
|
"""
|
|
301
|
-
|
|
293
|
+
inputs = mutable(inputs)
|
|
294
|
+
return _grad_unscale_map(self.scale_value, inputs)
|
|
302
295
|
|
|
296
|
+
@jit
|
|
303
297
|
def adjust(self, grads_finite):
|
|
304
298
|
"""
|
|
305
299
|
Adjust the `scale_value` dependent on whether grads are finite.
|
|
@@ -313,7 +307,7 @@ class DynamicLossScaler(LossScaler):
|
|
|
313
307
|
grads_finite,
|
|
314
308
|
ops.select(
|
|
315
309
|
self.counter == (self.scale_window - 1),
|
|
316
|
-
ops.select(
|
|
310
|
+
ops.select(ops.isfinite(scale_mul_factor),
|
|
317
311
|
scale_mul_factor,
|
|
318
312
|
self.scale_value),
|
|
319
313
|
self.scale_value),
|
|
@@ -327,5 +321,6 @@ class DynamicLossScaler(LossScaler):
|
|
|
327
321
|
__all__ = [
|
|
328
322
|
"DynamicLossScaleManager", "LossScaleManager", "FixedLossScaleManager",
|
|
329
323
|
"build_train_network", "DynamicLossScaler", "StaticLossScaler", "LossScaler",
|
|
330
|
-
"auto_mixed_precision", "
|
|
324
|
+
"auto_mixed_precision", "all_finite", "custom_mixed_precision",
|
|
325
|
+
"get_white_list", "get_black_list"
|
|
331
326
|
]
|
mindspore/boost/boost.py
CHANGED
|
@@ -156,8 +156,8 @@ class AutoBoost:
|
|
|
156
156
|
|
|
157
157
|
Here:
|
|
158
158
|
|
|
159
|
-
- pca_mat (array): Shape (k*n)
|
|
160
|
-
- bk (array): Shape (k*k)
|
|
159
|
+
- pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight.
|
|
160
|
+
- bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method.
|
|
161
161
|
|
|
162
162
|
we need to find the m satisfy:
|
|
163
163
|
|
|
@@ -27,6 +27,7 @@ from mindspore.common import Tensor
|
|
|
27
27
|
from mindspore.common.sparse_tensor import RowTensorInner
|
|
28
28
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
29
29
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
|
30
|
+
from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
|
|
30
31
|
from mindspore.ops import functional as F
|
|
31
32
|
from mindspore.ops import composite as C
|
|
32
33
|
from mindspore.ops import operations as P
|
|
@@ -115,7 +116,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
|
|
|
115
116
|
sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
|
116
117
|
|
|
117
118
|
Inputs:
|
|
118
|
-
-
|
|
119
|
+
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
|
119
120
|
|
|
120
121
|
Outputs:
|
|
121
122
|
Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.
|
|
@@ -392,7 +393,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
392
393
|
is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`.
|
|
393
394
|
|
|
394
395
|
Inputs:
|
|
395
|
-
-
|
|
396
|
+
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
|
396
397
|
|
|
397
398
|
Outputs:
|
|
398
399
|
Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.
|
|
@@ -460,6 +461,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
460
461
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
|
461
462
|
self.gpu_target = (context.get_context("device_target") == "GPU")
|
|
462
463
|
self.loss_scaling_manager = None
|
|
464
|
+
self.base0 = Tensor(0, mstype.int32)
|
|
465
|
+
self.reduce_all = P.ReduceAll(keep_dims=False)
|
|
466
|
+
self.reduce_any = P.ReduceAny(keep_dims=False)
|
|
467
|
+
self.equal = P.Equal()
|
|
468
|
+
self.not_equal = P.NotEqual()
|
|
463
469
|
|
|
464
470
|
if self.auto_boost.boost_config.get("loss_scale_group", False):
|
|
465
471
|
self.enable_enhanced_amp = True
|
|
@@ -535,12 +541,13 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
535
541
|
bool, overflow value.
|
|
536
542
|
float, update ratio.
|
|
537
543
|
"""
|
|
538
|
-
flag_sum = self.
|
|
544
|
+
flag_sum = self.equal(self.base0, param)
|
|
539
545
|
if self.reducer_flag:
|
|
540
546
|
flag_reduce = self.allreduce(flag_sum)
|
|
541
|
-
overflow = self.
|
|
547
|
+
overflow = not self.reduce_all(flag_reduce)
|
|
542
548
|
else:
|
|
543
|
-
overflow = self.
|
|
549
|
+
overflow = not self.reduce_all(flag_sum)
|
|
550
|
+
|
|
544
551
|
if overflow:
|
|
545
552
|
update_ratio = self.reduce_ratio
|
|
546
553
|
else:
|
|
@@ -609,13 +616,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
609
616
|
The second value is the same as the input of `compute_input`, but contains some information about the
|
|
610
617
|
execution order.
|
|
611
618
|
"""
|
|
612
|
-
status =
|
|
619
|
+
status = Tensor([0] * 8, mstype.int32)
|
|
613
620
|
if not self.gpu_target:
|
|
614
|
-
# init overflow buffer
|
|
615
|
-
status = P.NPUAllocFloatStatus()()
|
|
616
621
|
status = F.depend(status, pre_cond)
|
|
617
622
|
# clear overflow buffer
|
|
618
|
-
clear_status =
|
|
623
|
+
clear_status = NPUClearFloatStatusV2()(status)
|
|
619
624
|
compute_input = F.depend(compute_input, clear_status)
|
|
620
625
|
return status, compute_input
|
|
621
626
|
|
|
@@ -636,22 +641,35 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
636
641
|
"""
|
|
637
642
|
if not self.gpu_target:
|
|
638
643
|
status = F.depend(status, compute_output)
|
|
639
|
-
get_status =
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
644
|
+
get_status = NPUGetFloatStatusV2()(status)
|
|
645
|
+
|
|
646
|
+
if self.is_distributed:
|
|
647
|
+
# sum overflow flag over devices
|
|
648
|
+
flag_reduce = self.allreduce(get_status)
|
|
649
|
+
# get_status not equal to [0]*8 means overflow
|
|
650
|
+
flag = self.not_equal(self.base0, flag_reduce)
|
|
651
|
+
status = F.depend(status, flag)
|
|
652
|
+
# distributed needs to skip allreduce to avoid its overflow affecting the next step
|
|
653
|
+
clear_status = NPUClearFloatStatusV2()(status)
|
|
654
|
+
flag = F.depend(flag, clear_status)
|
|
655
|
+
else:
|
|
656
|
+
status = F.depend(status, get_status)
|
|
657
|
+
clear_status = NPUClearFloatStatusV2()(status)
|
|
658
|
+
get_status = F.depend(get_status, clear_status)
|
|
659
|
+
flag = self.not_equal(self.base0, get_status)
|
|
660
|
+
overflow = self.reduce_any(flag)
|
|
643
661
|
else:
|
|
644
662
|
flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
|
|
645
663
|
flag_sum = P.AddN()(flag_sum)
|
|
646
664
|
# convert flag_sum to scalar
|
|
647
665
|
flag_sum = P.Reshape()(flag_sum, (()))
|
|
648
666
|
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
667
|
+
if self.is_distributed:
|
|
668
|
+
# sum overflow flag over devices
|
|
669
|
+
flag_reduce = self.allreduce(flag_sum)
|
|
670
|
+
overflow = self.less_equal(self.base, flag_reduce)
|
|
671
|
+
else:
|
|
672
|
+
overflow = self.less_equal(self.base, flag_sum)
|
|
655
673
|
return overflow
|
|
656
674
|
|
|
657
675
|
def _process_loss_scale(self, overflow):
|
|
@@ -688,7 +706,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
|
|
|
688
706
|
self.optimizer_loss_scale = [self.parent.count(x) for x in parent_set]
|
|
689
707
|
self.reduce_ratio = Tensor(1.0 / (2 ** 0.5), mstype.float32)
|
|
690
708
|
self.growth_ratio = Tensor(2 ** (1.0 / 1000.0), mstype.float32)
|
|
691
|
-
self.overflow_status_list = ParameterTuple(Parameter(Tensor(np.zeros(shape=[8]), mstype.
|
|
709
|
+
self.overflow_status_list = ParameterTuple(Parameter(Tensor(np.zeros(shape=[8]), mstype.int32),
|
|
692
710
|
name='mix_layer_status_{}'.format(x), requires_grad=False)
|
|
693
711
|
for x in range(loss_scale_number))
|
|
694
712
|
self.loss_scaling_manager.set_loss_scale_status(loss_scale_number, self.loss_scaling_manager.get_loss_scale())
|
mindspore/boost/dim_reduce.py
CHANGED
|
@@ -102,8 +102,8 @@ class DimReduce(Cell):
|
|
|
102
102
|
|
|
103
103
|
Here:
|
|
104
104
|
|
|
105
|
-
- pca_mat (array): Shape (k*n)
|
|
106
|
-
- bk (array): Shape (k*k)
|
|
105
|
+
- pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight.
|
|
106
|
+
- bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method.
|
|
107
107
|
|
|
108
108
|
we need to find the m satisfy:
|
|
109
109
|
|
|
@@ -138,7 +138,7 @@ class DimReduce(Cell):
|
|
|
138
138
|
- **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors.
|
|
139
139
|
- **weight** (Tuple(Tensor)) - Tuple of parameters.
|
|
140
140
|
- **weight_clone** (Tuple(Tensor)) - clone of weight
|
|
141
|
-
-
|
|
141
|
+
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
|
142
142
|
|
|
143
143
|
Outputs:
|
|
144
144
|
- **loss** (Tensor) - Tensor with shape :math:`()`.
|
|
@@ -93,7 +93,7 @@ class GroupLossScaleManager(Cell):
|
|
|
93
93
|
>>> boost_level="O1", boost_config_dict=boost_config_dict)
|
|
94
94
|
>>> # For details about how to build the dataset, please refer to the variable `dataset_train` in tutorial
|
|
95
95
|
>>> # document on the official website:
|
|
96
|
-
>>> # https://www.mindspore.cn/tutorials/zh-CN/r2.0
|
|
96
|
+
>>> # https://www.mindspore.cn/tutorials/zh-CN/r2.0/beginner/quick_start.html
|
|
97
97
|
>>> dataset = create_custom_dataset()
|
|
98
98
|
>>> model.train(2, dataset)
|
|
99
99
|
"""
|
mindspore/common/__init__.py
CHANGED
|
@@ -15,13 +15,13 @@
|
|
|
15
15
|
"""Top-level reference to dtype of common module."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
from mindspore.common import dtype
|
|
18
|
-
from mindspore.common.api import
|
|
18
|
+
from mindspore.common.api import ms_function, ms_memory_recycle, ms_class, jit, jit_class
|
|
19
19
|
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
|
20
20
|
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
|
21
21
|
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
|
22
22
|
uint, number, tensor, string, type_none, tensor_type, Int, \
|
|
23
23
|
complex64, complex128, dtype_to_nptype, _null, _null_type, \
|
|
24
|
-
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
|
|
24
|
+
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype, QuantDtype
|
|
25
25
|
from mindspore.common.dump import set_dump
|
|
26
26
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
27
27
|
from mindspore.common.seed import set_seed, get_seed
|
|
@@ -29,7 +29,6 @@ from mindspore.common.tensor import Tensor
|
|
|
29
29
|
from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor
|
|
30
30
|
from mindspore.common.mutable import mutable
|
|
31
31
|
from mindspore.common.jit_config import JitConfig
|
|
32
|
-
from mindspore.common._utils import update_and_return_dict
|
|
33
32
|
|
|
34
33
|
# symbols from dtype
|
|
35
34
|
__all__ = [
|
|
@@ -50,7 +49,7 @@ __all__ = [
|
|
|
50
49
|
"number", "tensor",
|
|
51
50
|
"string", "type_none",
|
|
52
51
|
"_null",
|
|
53
|
-
"tensor_type",
|
|
52
|
+
"tensor_type", "QuantDtype",
|
|
54
53
|
"Type", "Int", "_null_type",
|
|
55
54
|
"complex64", "complex128",
|
|
56
55
|
# __method__ from dtype
|
|
@@ -60,12 +59,11 @@ __all__ = [
|
|
|
60
59
|
|
|
61
60
|
__all__.extend([
|
|
62
61
|
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
|
63
|
-
"
|
|
62
|
+
"ms_function", "ms_class", 'jit', 'jit_class', # api
|
|
64
63
|
"Parameter", "ParameterTuple", # parameter
|
|
65
64
|
"dtype",
|
|
66
65
|
"set_seed", "get_seed", # random seed
|
|
67
66
|
"set_dump",
|
|
68
67
|
"ms_memory_recycle",
|
|
69
68
|
"mutable", "JitConfig",
|
|
70
|
-
"update_and_return_dict",
|
|
71
69
|
])
|
mindspore/common/_decorator.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Providing decorators."""
|
|
16
16
|
|
|
17
17
|
from __future__ import absolute_import
|
|
18
|
+
from functools import wraps
|
|
18
19
|
from mindspore import log
|
|
19
20
|
|
|
20
21
|
|
|
@@ -31,6 +32,7 @@ def deprecated(version, substitute, use_substitute_name=False):
|
|
|
31
32
|
"""
|
|
32
33
|
|
|
33
34
|
def decorate(func):
|
|
35
|
+
@wraps(func)
|
|
34
36
|
def wrapper(*args, **kwargs):
|
|
35
37
|
cls = getattr(args[0], "__class__", None) if args else None
|
|
36
38
|
name = cls.__name__ if cls else func.__name__
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""Registry MSAdapter config."""
|
|
17
|
+
|
|
18
|
+
from mindspore.common.tensor import Tensor
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Registry:
|
|
22
|
+
"""Registry class for ms adapter."""
|
|
23
|
+
|
|
24
|
+
def __init__(self):
|
|
25
|
+
self._tensor = None
|
|
26
|
+
self._convert_map = {}
|
|
27
|
+
|
|
28
|
+
@property
|
|
29
|
+
def tensor(self):
|
|
30
|
+
"""Return the registered tensor."""
|
|
31
|
+
if self._tensor is None:
|
|
32
|
+
raise ValueError("Before using Tensor in MSAdapter, please call 'set_adapter_config'.")
|
|
33
|
+
return self._tensor
|
|
34
|
+
|
|
35
|
+
@property
|
|
36
|
+
def convert_map(self):
|
|
37
|
+
"""Return the registered convert map."""
|
|
38
|
+
return self._convert_map
|
|
39
|
+
|
|
40
|
+
def register_tensor(self, value):
|
|
41
|
+
"""Register the tensor of ms adapter."""
|
|
42
|
+
if self._tensor is not None:
|
|
43
|
+
raise ValueError("Repeated registration of tensor in ms adapter config.")
|
|
44
|
+
if not issubclass(value, Tensor):
|
|
45
|
+
raise ValueError(f"The tensor definition here should be a subclass of ms.Tensor, but got {value}.")
|
|
46
|
+
self._tensor = value
|
|
47
|
+
|
|
48
|
+
def register_convert_map(self, value):
|
|
49
|
+
"""Register the convert map of ms adapter."""
|
|
50
|
+
if not isinstance(value, dict):
|
|
51
|
+
raise ValueError(f"Expect a dict type, but got {type(value)}.")
|
|
52
|
+
self._convert_map = value
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
ms_adapter_registry = Registry()
|