mindspore 2.0.0a0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
- mindspore/__init__.py +4 -2
- mindspore/_akg/akg/composite/build_module.py +11 -0
- mindspore/_akg/akg/config/repository_cuda.json +11 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +52 -57
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/config/super_bar_config.json +512 -0
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libicudata.so.69 +0 -0
- mindspore/lib/libicui18n.so.69 +0 -0
- mindspore/lib/libicuuc.so.69 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so → libmindspore_ascend.so.2} +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/scipy/linalg.py +10 -114
- mindspore/scipy/ops.py +2 -2
- mindspore/scipy/ops_wrapper.py +1 -1
- mindspore/scipy/optimize/_bfgs.py +1 -1
- mindspore/scipy/optimize/_lagrange.py +200 -0
- mindspore/scipy/optimize/line_search.py +3 -2
- mindspore/scipy/optimize/minimize.py +41 -2
- mindspore/scipy/sparse/__init__.py +2 -2
- mindspore/scipy/sparse/linalg.py +4 -464
- mindspore/scipy/utils.py +1 -1
- mindspore/scipy/utils_const.py +7 -1
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +648 -574
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/spectral_ops.py CHANGED

@@ -14,7 +14,7 @@
 # ============================================================================

 """Spectral operators."""
-from mindspore
+from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import Primitive, prim_attr_register

@@ -23,10 +23,26 @@ class BartlettWindow(Primitive):
     r"""
     Bartlett window function.

+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
     Refer to :func:`mindspore.ops.bartlett_window` for more details.

+    Args:
+        periodic (bool, optional): If True, returns a window to be used as periodic function.
+            If False, return a symmetric window. Default: True.
+        dtype (mindspore.dtype, optional): The desired datatype of returned tensor.
+            Only float16, float32 and float64 are allowed. Default: mstype.float32.
+
+    Inputs:
+        - **window_length** (Tensor) - The size of returned window, with data type int32, int64.
+          The input data should be an integer with a value of [0, 1000000].
+
+    Outputs:
+        A 1-D tensor of size `window_length` containing the window. Its datatype is set by the attr `dtype`.
+
     Supported Platforms:
-        ``GPU`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> window_length = Tensor(5, mstype.int32)
@@ -50,10 +66,26 @@ class BlackmanWindow(Primitive):
     r"""
     Blackman window function.

+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
     Refer to :func:`mindspore.ops.blackman_window` for more details.

+    Args:
+        periodic (bool, optional): If True, returns a window to be used as periodic function.
+            If False, return a symmetric window. Default: True.
+        dtype (mindspore.dtype, optional): the desired data type of returned tensor.
+            Only float16, float32 and float64 is allowed. Default: mstype.float32.
+
+    Inputs:
+        - **window_length** (Tensor) - the size of returned window, with data type int32, int64.
+          The input data should be an integer with a value of [0, 1000000].
+
+    Outputs:
+        A 1-D tensor of size `window_length` containing the window. Its datatype is set by the attr `dtype`.
+
     Supported Platforms:
-        ``GPU`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> window_length = Tensor(10, mindspore.int32)
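The docstring hunks above promote both window primitives from GPU/CPU-only to Ascend as well and document their full contract inline. As a minimal usage sketch of that contract (illustrative code assuming a standard MindSpore 2.0 install, not lines quoted from the diff):

from mindspore import Tensor, ops
from mindspore.common import dtype as mstype

# Per the new Inputs section: window_length is an int32/int64 Tensor in [0, 1000000].
window_length = Tensor(5, mstype.int32)
# periodic=True (the default) returns a periodic window; dtype must be float16/32/64.
bartlett = ops.BartlettWindow(periodic=True, dtype=mstype.float32)
output = bartlett(window_length)
print(output.shape)  # (5,): a 1-D tensor of size `window_length`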
mindspore/ops/primitive.py CHANGED

@@ -24,8 +24,9 @@ from mindspore.parallel._utils import _is_in_auto_parallel_mode, _is_in_data_par
 from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
 from mindspore.common.parameter import Parameter
 from mindspore.common.api import _pynative_executor
+from mindspore.common._stub_tensor import _convert_stub
 from mindspore._c_expression import Primitive_, prim_type
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.ops import signature as sig

@@ -486,10 +487,10 @@ class PrimitiveWithCheck(Primitive):
         ...     def __init__(self):
         ...         pass
         ...     def check_shape(self, input_x):
-        ...
+        ...         Validator.check_int(len(input_x), 1, validator.GE, 'input_x rank', self.name)
         ...
         ...     def check_dtype(self, input_x):
-        ...
+        ...         Validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
         ...
         >>> # init a Primitive obj
         >>> add = Flatten()
@@ -501,10 +502,18 @@ class PrimitiveWithCheck(Primitive):

     def __check__(self, *args):
         """Checking the input shape and the input type of ops is valid """
-
-        for
-
-
+        check_dtype_fn = getattr(self, 'check_dtype')
+        check_dtype_fn(*(x['dtype'] for x in args))
+
+        is_shape_known = True
+        for x in args:
+            shape = x['shape']
+            if shape is None or -1 in shape or -2 in shape:
+                is_shape_known = False
+                break
+        if is_shape_known:
+            check_shape_fn = getattr(self, 'check_shape')
+            check_shape_fn(*(x['shape'] for x in args))

     def _clone(self):
         """
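The rewritten `__check__` always runs `check_dtype`, but gates `check_shape` on every input shape being fully known. A self-contained sketch of that gating rule (`_is_shape_known` is an illustrative name, not part of the diff; `-1` and `-2` follow MindSpore's convention for an unknown dimension and unknown rank):

def _is_shape_known(shapes):
    # Mirrors the loop added above: a None shape, a dynamic dimension (-1),
    # or a dynamic rank (-2) defers shape validation until runtime.
    for shape in shapes:
        if shape is None or -1 in shape or -2 in shape:
            return False
    return True

print(_is_shape_known([(2, 3), (4,)]))  # True  -> check_shape runs
print(_is_shape_known([(2, -1)]))       # False -> dynamic dim, check skipped
print(_is_shape_known([(-2,)]))         # False -> dynamic rank, check skipped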
@@ -731,6 +740,24 @@ def prim_attr_register(fn):
     return deco


+def _check_contains_variable(item_dtype, item_value):
+    """
+    Check whether the item is or contains variable.
+    """
+    if isinstance(item_value, (list, tuple)):
+        for i, element in enumerate(item_value):
+            if _check_contains_variable(item_dtype[i], element):
+                return True
+    elif isinstance(item_value, dict):
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
+                return True
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], list(item_value.values())[i]):
+                return True
+    return item_dtype is not None and item_value is None
+
+
 def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
     """
     Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
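In the infer dicts that flow through `__infer__`, a known dtype paired with a `None` value marks a graph-time variable rather than a compile-time constant, and the new helper recurses through lists, tuples, and dict keys/values to find one. A few illustrative calls (the import path is an assumption; the helper is private):

from mindspore.common import dtype as mstype
from mindspore.ops.primitive import _check_contains_variable  # private helper added above

print(_check_contains_variable(mstype.int64, None))  # True: dtype known, value unknown
print(_check_contains_variable(mstype.int64, 3))     # False: plain constant
print(_check_contains_variable([mstype.int64, mstype.int64], (1, None)))  # True: nested variable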
@@ -778,13 +805,14 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
                 PrimitiveWithInfer.__init__(self, op_name)
                 self.set_const_prim(True)
                 self.fn = fn
+                self.add_prim_attr('constexpr_prim', True)
                 if not reuse_result:
                     self.add_prim_attr('forbid_reuse_result', True)

             def __infer__(self, *args):
                 value_args = []
                 for item in args:
-                    if (item["dtype"]
+                    if _check_contains_variable(item["dtype"], item["value"]) and check:
                         logger.warning("The \"" + self.name + "\" is a constexpr function." \
                                        " The input arguments must be all constant value.")
                     value_args.append(item["value"])
@@ -802,8 +830,82 @@
     return deco


-
+def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
+    """
+    _primexpr is similar as constexpr except that when the input to the function decorated by _primexpr contains
+    variable, the function will be compiled as graph.
+
+    _primexpr is only for internal use.
+
+    Args:
+        fn (function): A `fn` use as the infer_value of the output operator. Default: None.
+        get_instance (bool): If true, return the instance of operator,
+            otherwise return the operator class. Default: True.
+        name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
+        reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
+            otherwise the operator will always be executed. Default: True.
+    """
+    def deco(fn):
+        """Decorator for CompileOp."""
+
+        class CompileOp(PrimitiveWithInfer):
+            """
+            CompileOp is a temporary operator used to execute the constexpr function.
+            """
+
+            def __init__(self):
+                op_name = name if name else fn.__name__
+                PrimitiveWithInfer.__init__(self, op_name)
+                self.set_const_prim(True)
+                self.fn = fn
+                self.add_prim_attr('constexpr_prim', True)
+                if not reuse_result:
+                    self.add_prim_attr('forbid_reuse_result', True)
+
+            def __infer__(self, *args):
+                value_args = []
+                for item in args:
+                    if _check_contains_variable(item["dtype"], item["value"]):
+                        return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
+                    value_args.append(item["value"])
+                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
+
+            def __call__(self, *args, **kwargs):
+                return fn(*args, **kwargs)
+
+        if get_instance:
+            return CompileOp()
+        return CompileOp
+
+    if fn is not None:
+        return deco(fn)
+    return deco
+
+
+_RUN_OP_ASYNC = True
+
+
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
+    if _RUN_OP_ASYNC:
+        stub = _pynative_executor.run_op_async(obj, args)
+        return _convert_stub(stub)
+    return _run_op_sync(obj, op_name, args)
+
+
+@_wrap_func
+def _run_op_sync(obj, op_name, args):
+    """Single op execution function in synchronous mode."""
     output = _pynative_executor.real_run_op(obj, op_name, args)
     return output
+
+
+class _PrimitiveC(Primitive):
+    def __init__(self, name, attrs):
+        super().__init__(name)
+        for key, value in attrs.items():
+            super().add_prim_attr(key, value)
+
+
+def _get_primitivec(name, attrs):
+    return _PrimitiveC(name, attrs)
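`_primexpr` is the centerpiece of this hunk: where `constexpr` merely warns on a non-constant argument, `_primexpr`'s `__infer__` hands the Python function back to the compiler (`'fn': (fn,)`) so its body can be compiled into the graph instead of evaluated eagerly. A hypothetical internal usage following the same pattern as `constexpr` (the decorated function and its arguments are illustrative, not from the diff):

from mindspore.ops.primitive import _primexpr  # private API, internal use only

@_primexpr
def _infer_pad(input_len, kernel_size):
    # Evaluated eagerly when both arguments are compile-time constants;
    # compiled as a graph when either one is a variable.
    return (kernel_size - 1) // 2 if input_len > 0 else 0

print(_infer_pad(16, 3))  # 1 -- __call__ forwards straight to the Python body

The same hunk also reroutes `_run_op` through `run_op_async` plus `_convert_stub` when `_RUN_OP_ASYNC` is set, keeping the old synchronous path available as `_run_op_sync`.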
mindspore/parallel/_auto_parallel_context.py CHANGED

@@ -15,13 +15,15 @@
 """Context of auto parallel"""
 from __future__ import absolute_import
 import os
+import copy
 import threading
 from mindspore import context
 import mindspore.log as logger
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check
+from mindspore import _checkparam as Validator

 _MAX_GROUP_NAME_LEN = 127
 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
@@ -40,6 +42,18 @@ class _ParallelFusionConfig:
     AUTO = "auto"
     INDEX = "index"
     SIZE = "size"
+    OPENSTATE = "openstate"
+    CONFIG = {"openstate": True,
+              "allreduce": {"mode": "auto", "config": None},
+              "allgather": {"mode": "auto", "config": None},
+              "reducescatter": {"mode": "auto", "config": None}}
+
+    @classmethod
+    def reset(cls):
+        cls.CONFIG = {"openstate": True,
+                      "allreduce": {"mode": "auto", "config": None},
+                      "allgather": {"mode": "auto", "config": None},
+                      "reducescatter": {"mode": "auto", "config": None}}


 class _ParallelOptimizerConfig:
@@ -117,6 +131,9 @@ class _AutoParallelContext:
             KeyError: When key of comm_fusion is not 'allreduce'.
         """
         self.check_context_handle()
+        config = copy.deepcopy(config)
+        if _ParallelFusionConfig.OPENSTATE not in config.keys():
+            config[_ParallelFusionConfig.OPENSTATE] = True
         for key in list(config.keys()):
             if key == _ParallelFusionConfig.ALLREDUCE:
                 self._set_allreduce_comm_fusion(config[key])
@@ -124,91 +141,18 @@ class _AutoParallelContext:
                 self._set_allgather_comm_fusion(config[key], key)
             elif key == _ParallelFusionConfig.REDUCESCATTER:
                 self._set_allgather_comm_fusion(config[key], key)
+            elif key == _ParallelFusionConfig.OPENSTATE:
+                self._set_openstate_comm_fusion(config[key])
             else:
-                raise KeyError("comm fusion type must be
+                raise KeyError("comm fusion type must be openstate,"
+                               "allreduce, allgather or reducescatter, but got {}".format(key))
+            if key in _ParallelFusionConfig.CONFIG:
+                _ParallelFusionConfig.CONFIG[key] = config[key]

     def get_comm_fusion(self):
         """Get comm fusion config."""
         self.check_context_handle()
-
-        if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE):
-            config = self.fusion_threshold_mb()
-        if mode == _ParallelFusionConfig.INDEX:
-            config = self.get_all_reduce_fusion_split_indices()
-        return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode,
-                                                  _ParallelFusionConfig.FUSION_CONFIG: config}}
-
-    def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
-        """
-        Set allgather and reducescatter fusion method for auto parallel.
-
-        Args:
-            comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
-                supports four fusion methods: `auto` and `size`.
-            comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
-
-        Raises:
-            KeyError: When key of comm_fusion is not 'mode' or 'config'.
-            KeyError: When `mode` is not 'auto', 'size'.
-        """
-        self.check_context_handle()
-        if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
-            return
-        if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
-            return
-        if not isinstance(comm_fusion, dict):
-            raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
-                comm_type, type(comm_fusion)))
-        if _ParallelFusionConfig.MODE not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
-        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
-        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
-        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
-            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
-        else:
-            raise KeyError("fusion method mode must be auto or size, but got {}".format(
-                comm_fusion[_ParallelFusionConfig.MODE]))
-
-        fusion_threshold = 64
-        if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
-            fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
-        self.set_fusion_threshold_mb(fusion_threshold, comm_type)
-
-    def _set_allreduce_comm_fusion(self, comm_fusion):
-        """
-        Set fusion method for auto parallel.
-
-        Args:
-            comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
-                supports four fusion methods: `auto`, `size` and `index`.
-
-        Raises:
-            KeyError: When key of comm_fusion is not 'mode' or 'config'.
-            KeyError: When `mode` is not 'auto', 'size' or 'index'.
-        """
-        self.check_context_handle()
-        if not self.get_enable_all_reduce_fusion():
-            return
-        if not isinstance(comm_fusion, dict):
-            raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
-                type(comm_fusion)))
-        if _ParallelFusionConfig.MODE not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
-        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
-        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
-        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
-            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
-        else:
-            raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
-                comm_fusion[_ParallelFusionConfig.MODE]))
-        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
-            self.set_fusion_threshold_mb(fusion_threshold=64)
-        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
-            self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
-        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
-            self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+        return _ParallelFusionConfig.CONFIG

     def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
         """
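After this restructuring, `set_comm_fusion` mirrors every accepted key into the class-level `_ParallelFusionConfig.CONFIG`, and `get_comm_fusion` simply returns that dict; the new `openstate` switch defaults to True when omitted. A sketch of the effect through the public context API (values are illustrative, not from the diff):

import mindspore as ms

# "openstate" toggles communication fusion globally; leaving it out means True.
ms.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "size", "config": 32}})

# get_comm_fusion() would now report the mirrored dict, e.g.:
# {'openstate': True,
#  'allreduce': {'mode': 'size', 'config': 32},
#  'allgather': {'mode': 'auto', 'config': None},
#  'reducescatter': {'mode': 'auto', 'config': None}}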
@@ -521,6 +465,9 @@ class _AutoParallelContext:
             if not isinstance(dim, int):
                 raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                 "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
+        if context.get_context('mode') == context.PYNATIVE_MODE:
+            raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' "
+                             f"or 'data_parallel', but got {dataset_strategy}.")
         self._dataset_strategy_using_str = False
         self._context_handle.set_dataset_strategy(dataset_strategy)

@@ -531,7 +478,11 @@
             if self._context_handle.get_full_batch():
                 return "full_batch"
             return "data_parallel"
-
+        dataset_strategy = self._context_handle.get_dataset_strategy()
+        if context.get_context('mode') == context.PYNATIVE_MODE:
+            raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' "
+                             f"or 'data_parallel', but got the setting value is {dataset_strategy}.")
+        return dataset_strategy

     def set_grad_accumulation_step(self, grad_accumulation_step):
         """
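Both dataset_strategy hunks add the same guard: in PyNative mode only the string forms are accepted, on set and on get alike. An illustrative consequence (not code from the diff):

import mindspore as ms

ms.set_context(mode=ms.PYNATIVE_MODE)
ms.set_auto_parallel_context(dataset_strategy="data_parallel")  # still valid
# A tuple layout such as dataset_strategy=((1, 8),) now raises ValueError
# in PyNative mode; tuple strategies remain graph-mode only.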
@@ -567,6 +518,52 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_ckpt_save_file()

+    def set_strategy_ckpt_config(self, strategy_ckpt_config):
+        """
+        Set strategy checkpoint config.
+
+        Args:
+            strategy_ckpt_config (dict): The strategy checkpoint config.
+        """
+        self.check_context_handle()
+        if not isinstance(strategy_ckpt_config, dict):
+            raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
+                            "must be dict, but got the type : {}.".format(type(strategy_ckpt_config)))
+        for config_name in strategy_ckpt_config:
+            unknown_config = []
+            if config_name not in ["load_file", "save_file", "only_trainable_params"]:
+                unknown_config.append(config_name)
+
+            if unknown_config:
+                raise ValueError("Unknown config: {}".format(unknown_config))
+        if "load_file" in strategy_ckpt_config:
+            load_file = strategy_ckpt_config.get("load_file")
+            if not isinstance(load_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'load_file' must be str, but got the type : {} .".format(type(load_file)))
+            self._context_handle.set_strategy_ckpt_load_file(load_file)
+        if "save_file" in strategy_ckpt_config:
+            save_file = strategy_ckpt_config.get("save_file")
+            if not isinstance(save_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'save_file' must be str, but got the type : {} .".format(type(save_file)))
+            self._context_handle.set_strategy_ckpt_save_file(save_file)
+        if "only_trainable_params" in strategy_ckpt_config:
+            only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
+            if not isinstance(only_trainable_params, bool):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'only_trainable_params' must be bool,"
+                                " but got the type : {} .".format(type(only_trainable_params)))
+            self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
+
+    def get_strategy_ckpt_config(self):
+        """Get strategy checkpoint config."""
+        self.check_context_handle()
+        load_file = self._context_handle.get_strategy_ckpt_load_file()
+        save_file = self._context_handle.get_strategy_ckpt_save_file()
+        only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
+        return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
+
     def set_group_ckpt_save_file(self, group_ckpt_save_file):
         """Set group checkpoint save path."""
         self.check_context_handle()
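The new setter/getter pair recognizes exactly three keys: "load_file", "save_file" and "only_trainable_params"; anything else raises ValueError. A hedged sketch through the public wrapper (paths are hypothetical):

    import mindspore as ms

    ms.set_auto_parallel_context(strategy_ckpt_config={
        "save_file": "./strategy.ckpt",   # hypothetical output path
        "only_trainable_params": True,    # record only trainable parameters
    })
    # Keys that were not set are read back from the C++ context handle.
    print(ms.get_auto_parallel_context("strategy_ckpt_config"))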
@@ -912,6 +909,7 @@ class _AutoParallelContext:
         return self._context_handle.get_optimizer_weight_shard_aggregated_save()

     def get_full_batch_is_set(self):
+        """Get full batch attr"""
         self.check_context_handle()
         return self._context_handle.get_full_batch_is_set()

@@ -919,6 +917,7 @@ class _AutoParallelContext:
         """Reset all settings."""
         self.check_context_handle()
         self._context_handle.reset()
+        _ParallelFusionConfig.reset()

     def _check_and_default_group(self, group):
         """Validate the given group, if group is empty, returns a default fusion group"""
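With `_ParallelFusionConfig.reset()` added here, resetting the auto parallel context also clears the cached comm_fusion dict instead of letting it leak into the next configuration. For illustration, assuming the public wrappers:

    import mindspore as ms

    ms.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "auto", "config": None}})
    ms.reset_auto_parallel_context()   # now also clears the cached comm_fusion config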
@@ -936,6 +935,99 @@ class _AutoParallelContext:
             group = _DEFAULT_NCCL_FUSION_GROUP_NAME
         return group

+    def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
+        """
+        Set allgather and reducescatter fusion method for auto parallel.
+
+        Args:
+            comm_fusion (dict): A dict containing the methods and values for setting the fusion method. Currently it
+                supports two fusion methods: `auto` and `size`.
+            comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'mode' or 'config'.
+            KeyError: When `mode` is not 'auto' or 'size'.
+        """
+        self.check_context_handle()
+        if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
+            return
+        if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
+            return
+        if not isinstance(comm_fusion, dict):
+            raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
+                comm_type, type(comm_fusion)))
+        if _ParallelFusionConfig.MODE not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
+        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
+        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
+        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
+            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
+        else:
+            raise KeyError("fusion method mode must be auto or size, but got {}".format(
+                comm_fusion[_ParallelFusionConfig.MODE]))
+
+        fusion_threshold = 64
+        if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
+            fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
+        self.set_fusion_threshold_mb(fusion_threshold, comm_type)
+
+    def _set_allreduce_comm_fusion(self, comm_fusion):
+        """
+        Set allreduce fusion method for auto parallel.
+
+        Args:
+            comm_fusion (dict): A dict containing the methods and values for setting the fusion method. Currently it
+                supports three fusion methods: `auto`, `size` and `index`.
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'mode' or 'config'.
+            KeyError: When `mode` is not 'auto', 'size' or 'index'.
+        """
+        self.check_context_handle()
+        if not self.get_enable_all_reduce_fusion():
+            return
+        if not isinstance(comm_fusion, dict):
+            raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
+                type(comm_fusion)))
+        if _ParallelFusionConfig.MODE not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
+        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
+        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
+        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
+            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
+        else:
+            raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
+                comm_fusion[_ParallelFusionConfig.MODE]))
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
+            self.set_fusion_threshold_mb(fusion_threshold=64)
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
+            self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
+            self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+
+    def _set_openstate_comm_fusion(self, openstate):
+        """
+        Set open state for comm fusion.
+
+        Args:
+            openstate (bool): Whether to turn communication fusion on or off. Currently it
+                supports two states: `True` or `False`.
+
+        Raises:
+            TypeError: When the value is not bool.
+        """
+        self.check_context_handle()
+        if not isinstance(openstate, bool):
+            raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type : {}.".format(
+                type(openstate)))
+        if not openstate:
+            self.set_enable_all_reduce_fusion(openstate)
+            self.set_enable_all_gather_fusion(openstate)
+            self.set_enable_reduce_scatter_fusion(openstate)
+
+

 _AUTO_PARALLEL_CONTEXT = None

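The three private helpers above are presumably dispatched from `set_comm_fusion`: `openstate` is the global switch, while each per-operator entry selects a mode and threshold. A sketch of a combined configuration (dict shape per the docstring additions at the end of this diff):

    import mindspore as ms

    ms.set_auto_parallel_context(comm_fusion={
        "openstate": True,                               # keep fusion enabled globally
        "allreduce": {"mode": "size", "config": 32},     # fuse allreduce at a 32 MB threshold
        "allgather": {"mode": "auto", "config": None},   # default 64 MB threshold
        "reducescatter": {"mode": "size", "config": 16},
    })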
@@ -978,6 +1070,7 @@ _set_auto_parallel_context_func_map = {
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
     "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall,
+    "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
     "comm_fusion": auto_parallel_context().set_comm_fusion}


@@ -1005,6 +1098,7 @@ _get_auto_parallel_context_func_map = {
     "sharding_propagation": auto_parallel_context().get_sharding_propagation,
     "enable_alltoall": auto_parallel_context().get_enable_alltoall,
     "comm_fusion": auto_parallel_context().get_comm_fusion,
+    "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
     "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}


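Both maps implement the same string-to-callable dispatch that `_set_auto_parallel_context(**kwargs)` and its getter counterpart use to route each keyword, raising on unknown keys. A self-contained sketch of the pattern (illustrative names, not mindspore code):

    # Illustrative re-implementation of the dispatch pattern used above.
    _setters = {
        "comm_fusion": lambda value: print("set comm_fusion:", value),
        "strategy_ckpt_config": lambda value: print("set strategy_ckpt_config:", value),
    }

    def set_parallel_options(**kwargs):
        for key, value in kwargs.items():
            if key not in _setters:
                raise ValueError("Unknown auto parallel context key: {}".format(key))
            _setters[key](value)

    set_parallel_options(strategy_ckpt_config={"save_file": "./s.ckpt"})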
@@ -1014,7 +1108,8 @@ _get_auto_parallel_context_func_map = {
     strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
     grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
     communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
-    optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict)
+    optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
+    strategy_ckpt_config=dict)

 def _set_auto_parallel_context(**kwargs):
     """
@@ -1091,12 +1186,23 @@ def _set_auto_parallel_context(**kwargs):
             communication fusion config has two keys: "mode" and "config".
             It supports following communication fusion types and configurations:

+            - openstate: Whether to turn on the communication fusion or not. If `openstate` is `True`,
+              the communication fusion is turned on; otherwise, it is turned off. Default: `True`.
+
             - allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
               and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default
               fusion threshold is `64` MB. In `size` mode, allreduce fusion is configured by gradients size
               manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as
               `all_reduce_fusion_config`.

+            - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto` and `size`.
+              In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
+              threshold is `64` MB. In `size` mode, AllGather fusion is configured by gradients size
+              manually, and the fusion threshold must be larger than `0` MB.
+
+            - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
+              and `size`. Config is same as `allgather`.
+

         Raises:
             ValueError: If input key is not attribute in auto parallel context.