mindspore 1.10.0-cp38-cp38-win_amd64.whl → 2.0.0rc1-cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
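A listing like this can be reproduced locally. The sketch below is not the tool that generated this page; it assumes both wheels have already been fetched (for example with `pip download mindspore==1.10.0` and `pip download mindspore==2.0.0rc1`) and only reports added and removed archive members. The per-file `+N -N` line counts shown below would additionally require diffing the matching text entries.

import zipfile

# Assumed local filenames; adjust to wherever pip download placed the wheels.
OLD = "mindspore-1.10.0-cp38-cp38-win_amd64.whl"
NEW = "mindspore-2.0.0rc1-cp38-cp38-win_amd64.whl"

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    old_names, new_names = set(old_whl.namelist()), set(new_whl.namelist())
    for name in sorted(old_names - new_names):
        print("removed:", name)
    for name in sorted(new_names - old_names):
        print("added:  ", name)
    # Members present in both archives could then be compared with difflib
    # to obtain the per-file +N -N line counts shown in the listing.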
Potentially problematic release: this version of mindspore has been flagged as possibly problematic.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
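Many of the entries above are moves rather than edits; the `mindspore/{nn → train}/metrics/*` renames, for example, imply an import-path change for user code. A minimal migration sketch, assuming the public imports follow the file moves (the new `mindspore/nn/metrics.py +53 -0` above plausibly keeps the old path working as a deprecation shim):

# Old path in 1.10.x, per the original file locations:
# from mindspore.nn.metrics import Accuracy

# New path in 2.0.0rc1, per the moved files (assumed public API):
from mindspore.train.metrics import Accuracy

metric = Accuracy(eval_type="classification")
metric.clear()
metric.update([[0.2, 0.8], [0.9, 0.1]], [1, 0])  # (y_pred, y)
print(metric.eval())  # 1.0: both argmax predictions match the labels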
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
+
import json
|
|
19
20
|
import numpy as np
|
|
20
21
|
import mindspore as ms
|
|
21
22
|
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
|
|
@@ -39,7 +40,15 @@ def _convert_to_list(strategy):
|
|
|
39
40
|
field_size = int(layout.field)
|
|
40
41
|
shard_stride = int(layout.opt_weight_shard_step)
|
|
41
42
|
shard_size = int(layout.opt_weight_shard_size)
|
|
42
|
-
|
|
43
|
+
pipeline_stage = 0
|
|
44
|
+
origin_param_name = param_name
|
|
45
|
+
if "-" in param_name:
|
|
46
|
+
pipeline_stage, origin_param_name = param_name.split("-")
|
|
47
|
+
if origin_param_name not in train_map:
|
|
48
|
+
train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride,
|
|
49
|
+
shard_size, [int(pipeline_stage)]]
|
|
50
|
+
else:
|
|
51
|
+
train_map.get(origin_param_name)[6].append(int(pipeline_stage))
|
|
43
52
|
except BaseException as e:
|
|
44
53
|
raise ValueError(f"{e.__str__()}. Convert layout strategy to list "
|
|
45
54
|
f"failed, please make sure that strategy matches the node_strategy.proto, you can "
|
|
@@ -73,8 +82,8 @@ def _convert_to_layout(param_name, tensor_layout):
|
|
|
73
82
|
return strategy
|
|
74
83
|
|
|
75
84
|
|
|
76
|
-
def
|
|
77
|
-
"""
|
|
85
|
+
def _check_strategy_file(strategy_filename):
|
|
86
|
+
"""load parallel strategy file"""
|
|
78
87
|
if not isinstance(strategy_filename, str):
|
|
79
88
|
raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "
|
|
80
89
|
f"but got {type(strategy_filename)}.")
|
|
@@ -86,12 +95,25 @@ def _build_searched_strategy(strategy_filename):
|
|
|
86
95
|
if os.path.getsize(strategy_filename) == 0:
|
|
87
96
|
raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} should not "
|
|
88
97
|
f"be empty. Please check whether the 'strategy_filename' is correct.")
|
|
89
|
-
parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
|
|
90
98
|
|
|
99
|
+
|
|
100
|
+
def _load_protobuf_strategy(strategy_filename):
|
|
101
|
+
"""load strategy from protobuf file"""
|
|
102
|
+
parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
|
|
91
103
|
with open(strategy_filename, 'rb') as f:
|
|
92
104
|
pb_content = f.read()
|
|
93
|
-
|
|
105
|
+
try:
|
|
106
|
+
parallel_strategy_map.ParseFromString(pb_content)
|
|
107
|
+
except BaseException as e:
|
|
108
|
+
raise TypeError("The strategy file type should be one of json or protobuf. "
|
|
109
|
+
"When the file name extension is not '.json', "
|
|
110
|
+
"the file is considered as a protobuf file.") from e
|
|
111
|
+
return parallel_strategy_map
|
|
94
112
|
|
|
113
|
+
|
|
114
|
+
def _build_protobuf_strategy(strategy_filename):
|
|
115
|
+
"""build strategy from protobuf file"""
|
|
116
|
+
parallel_strategy_map = _load_protobuf_strategy(strategy_filename)
|
|
95
117
|
layout_items = parallel_strategy_map.parallel_layout_item
|
|
96
118
|
if not layout_items:
|
|
97
119
|
raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} has no sliced "
|
|
@@ -102,10 +124,143 @@ def _build_searched_strategy(strategy_filename):
|
|
|
102
124
|
parameter_name = layout_item.param_name
|
|
103
125
|
layout = layout_item.parallel_layouts
|
|
104
126
|
strategy[parameter_name] = layout
|
|
127
|
+
return strategy
|
|
128
|
+
|
|
105
129
|
|
|
130
|
+
def _build_json_strategy(strategy_filename):
|
|
131
|
+
"""build strategy from json file"""
|
|
132
|
+
with open(strategy_filename, 'r') as f:
|
|
133
|
+
json_content = json.load(f)
|
|
134
|
+
layout_items = json_content.get("parallel_layout_item")
|
|
135
|
+
strategy = {}
|
|
136
|
+
for parameter_name, layout_item in layout_items.items():
|
|
137
|
+
layout = ms.train.node_strategy_pb2.ParallelLayouts()
|
|
138
|
+
layout.field = layout_item.get("field")
|
|
139
|
+
layout.opt_weight_shard_size = layout_item.get("opt_weight_shard_size")
|
|
140
|
+
layout.opt_weight_shard_step = layout_item.get("opt_weight_shard_step")
|
|
141
|
+
dev_matrix = layout.dev_matrix.add()
|
|
142
|
+
for item in layout_item.get("dev_matrix"):
|
|
143
|
+
dev_matrix.dim.append(item)
|
|
144
|
+
tensor_map = layout.tensor_map.add()
|
|
145
|
+
for item in layout_item.get("tensor_map"):
|
|
146
|
+
tensor_map.dim.append(item)
|
|
147
|
+
param_split_shape = layout.param_split_shape.add()
|
|
148
|
+
if "param_split_shape" in layout_item:
|
|
149
|
+
for item in layout_item.get("param_split_shape"):
|
|
150
|
+
param_split_shape.dim.append(item)
|
|
151
|
+
indices_offset = layout.indices_offset.add()
|
|
152
|
+
if "indices_offset" in layout_item:
|
|
153
|
+
for item in layout_item.get("indices_offset"):
|
|
154
|
+
indices_offset.dim.append(item)
|
|
155
|
+
strategy[parameter_name] = layout
|
|
106
156
|
return strategy
|
|
107
157
|
|
|
108
158
|
|
|
159
|
+
def _build_searched_strategy(strategy_filename):
|
|
160
|
+
"""build searched strategy"""
|
|
161
|
+
_check_strategy_file(strategy_filename)
|
|
162
|
+
if strategy_filename[-5:] != ".json":
|
|
163
|
+
return _build_protobuf_strategy(strategy_filename)
|
|
164
|
+
return _build_json_strategy(strategy_filename)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
+def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file):
+    """merge protobuf strategy"""
+    dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
+    merged_stage = []
+    for src_strategy_file in src_strategy_files:
+        src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
+        strategy_items = src_parallel_strategy_map.parallel_strategy_item
+        layout_items = src_parallel_strategy_map.parallel_layout_item
+        if not strategy_items or not layout_items:
+            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
+        pipeline_stage = strategy_items[0].parallel_strategys.stage
+        if pipeline_stage in merged_stage:
+            continue
+        for layout_item in layout_items:
+            layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
+        dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
+        dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
+        merged_stage.append(pipeline_stage)
+    dst_parallel_strategy_map.current_stage = 1
+    with open(dst_strategy_file, "wb") as f:
+        f.write(dst_parallel_strategy_map.SerializeToString())
+
+
+def _merge_json_strategy(src_strategy_files, dst_strategy_file):
+    """merge json strategy"""
+    dst_parallel_strategy_map = {"current_stage": 1, "parallel_strategy_item": {}, "parallel_layout_item": {}}
+    merged_stage = []
+    for src_strategy_file in src_strategy_files:
+        with open(src_strategy_file, 'r') as f:
+            json_content = json.load(f)
+        layout_items = json_content.get("parallel_layout_item")
+        strategy_items = json_content.get("parallel_strategy_item")
+        if not strategy_items or not layout_items:
+            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
+        pipeline_stage = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
+        if pipeline_stage in merged_stage:
+            continue
+        for param_name, layout_item in layout_items.items():
+            new_layout_item = {}
+            new_param_name = "-".join([str(pipeline_stage), param_name])
+            new_layout_item[new_param_name] = layout_item
+            dst_parallel_strategy_map.get("parallel_layout_item").update(new_layout_item)
+        dst_parallel_strategy_map.get("parallel_strategy_item").update(strategy_items)
+        merged_stage.append(pipeline_stage)
+    with open(dst_strategy_file, "w") as f:
+        json.dump(dst_parallel_strategy_map, f)
+
+
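Both mergers deduplicate by pipeline stage and prefix each parameter name with its stage id. A hedged sketch via the public wrapper `mindspore.merge_pipeline_strategys`, which fans out to the protobuf or json merger shown above (paths are hypothetical; the source directory holds one strategy file per pipeline stage):

    import mindspore as ms

    # Merge the per-stage strategy files into a single file for checkpoint
    # transformation across the whole pipeline.
    ms.merge_pipeline_strategys("./src_pipeline_strategys", "./merged_strategy.ckpt")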
+def _parameter_not_in_local_stage(param_name, origin_strategy_list, strategy_list):
+    """Check whether the parameter is not in the local stage"""
+    if origin_strategy_list is None or strategy_list is None:
+        return True
+    return param_name in origin_strategy_list and param_name not in strategy_list
+
+
+def _extract_layout_map(strategy_file):
+    """Extract layout map"""
+    layout_map = None
+    if strategy_file is not None:
+        src_strategy = _build_searched_strategy(strategy_file)
+        layout_map = _convert_to_list(src_strategy)
+    return layout_map
+
+
+def _extract_pipeline_stage_num(strategy_file):
+    """extract pipeline stage num"""
+    pipeline_stage_num = 1
+    if strategy_file is not None:
+        src_strategy = _build_searched_strategy(strategy_file)
+        layout_map = _convert_to_list(src_strategy)
+        pipeline_stage_set = set()
+        for _, layout in layout_map.items():
+            pipeline_stage_set.update(layout[6])
+        pipeline_stage_num = len(pipeline_stage_set)
+        if list(pipeline_stage_set) != list(range(pipeline_stage_num)):
+            raise ValueError("The strategy file for pipeline parallel does not contain all stages.")
+    return pipeline_stage_num
+
+
+def _extract_src_dst_layout_map(rank_id, src_strategy_file=None, dst_strategy_file=None):
+    """Extract strategy list"""
+    src_layout_map = _extract_layout_map(src_strategy_file)
+    dst_layout_map = _extract_layout_map(dst_strategy_file)
+    if dst_layout_map is None:
+        return src_layout_map, dst_layout_map
+    dst_stage_device_num = np.prod(dst_layout_map.get(list(dst_layout_map.keys())[0])[0])
+    dst_stage_id = rank_id // dst_stage_device_num
+    # cut the source and destination layout, keeping only the parameters in the dst_stage
+    for param_name in list(dst_layout_map.keys()):
+        if dst_stage_id in dst_layout_map.get(param_name)[6]:
+            continue
+        dst_layout_map.pop(param_name)
+        if src_layout_map is not None and param_name in src_layout_map:
+            src_layout_map.pop(param_name)
+    return src_layout_map, dst_layout_map
+
+
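After `_convert_to_list`, each layout is indexed positionally: reading off the indexing above, entry 0 is the device matrix and entry 6 lists the pipeline stages that own the parameter (the entries between are placeholders in this sketch). A standalone sketch of the stage-count check with made-up layouts:

    layout_map = {
        "embedding.weight": ([4, 2], [1, -1], [], 0, 0, 0, [0]),
        "head.weight":      ([4, 2], [1, -1], [], 0, 0, 0, [1]),
    }
    stages = set()
    for _, layout in layout_map.items():
        stages.update(layout[6])
    # Stages must form a contiguous range starting at 0, otherwise the
    # strategy file is missing a stage.
    assert sorted(stages) == list(range(len(stages)))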
 def _restore_group_info_list(group_info_file_name):
     """restore group info"""
     parallel_group_map = ms.train.node_strategy_pb2.ParallelGroupMap()
@@ -122,6 +277,7 @@ def _restore_group_info_list(group_info_file_name):
 
 
 def _get_device_num_from_strategy(strategy_file=None):
+    """Get device num from strategy file"""
     if strategy_file is None:
         return 1
     src_strategy = _build_searched_strategy(strategy_file)
@@ -130,23 +286,14 @@ def _get_device_num_from_strategy(strategy_file=None):
     return np.prod(device_mat)
 
 
-def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_file=None, dst_strategy_file=None):
+def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst_strategy_list):
     """
     Get the needed rank list for transform model parallel dim of checkpoint.
     """
-    if src_strategy_file is None:
-        return [rank_id]
-    src_strategy = _build_searched_strategy(src_strategy_file)
-    src_strategy_list = _convert_to_list(src_strategy)
-    if not src_strategy_list:
-        raise ValueError("The src_strategy_file is empty.")
-    if dst_strategy_file is not None:
-        dst_strategy = _build_searched_strategy(dst_strategy_file)
-        dst_strategy_list = _convert_to_list(dst_strategy)
     result_list = set()
     handled_layout = []
     for param_name, _ in src_strategy_list.items():
-        if dst_strategy_file is not None and param_name not in dst_strategy_list:
+        if dst_strategy_list is not None and param_name not in dst_strategy_list:
             continue
         from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
             src_strategy_list.get(param_name))
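`_rank_list_for_transform_parallel_checkpoint` no longer reads strategy files itself; it receives the already-converted layout lists, and the file handling lives in `_extract_src_dst_layout_map` above. A hedged sketch via the public wrapper `mindspore.rank_list_for_transform` (file names are hypothetical):

    import mindspore as ms

    # Which source ranks' checkpoints are needed to rebuild rank 0 under the
    # destination strategy?
    needed = ms.rank_list_for_transform(rank_id=0,
                                        src_strategy_file="src_strategy.ckpt",
                                        dst_strategy_file="dst_strategy.ckpt")
    print(needed)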
@@ -156,7 +303,7 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_file=None, dst_strategy_file=None):
         to_tensor_map = [-1] * len(fake_tensor_shape)
         to_opt_shard_step = 0
         to_opt_shard_size = 0
-        if dst_strategy_file is not None:
+        if dst_strategy_list is not None:
             to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                 dst_strategy_list.get(param_name))
         handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
@@ -188,18 +335,10 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_file=None, dst_strategy_file=None):
     return list(result_list)
 
 
-def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_file=None,
-                                   dst_strategy_file=None):
+def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list, dst_strategy_list):
     """
     Transform model parallel dimension for distributed checkpoint files.
     """
-    device_num = rank_id + 1
-    if src_strategy_file is not None:
-        src_strategy = _build_searched_strategy(src_strategy_file)
-        src_strategy_list = _convert_to_list(src_strategy)
-    if dst_strategy_file is not None:
-        dst_strategy = _build_searched_strategy(dst_strategy_file)
-        dst_strategy_list = _convert_to_list(dst_strategy)
     transform_param_dict = {}
     for param_name, _ in param_total_dict.items():
         tensor_shape = list(param_total_dict[param_name].values())[0].shape
@@ -207,7 +346,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_file=None,
         from_tensor_map = [-1] * len(tensor_shape)
         from_opt_shard_step = 0
         from_opt_shard_size = 0
-        if src_strategy_file is not None:
+        if src_strategy_list is not None:
             if param_name not in src_strategy_list:
                 ms.log.warning("The parameter {} is not in src_strategy.".format(param_name))
                 continue
@@ -217,7 +356,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_file=None,
         to_tensor_map_origin = [-1] * len(tensor_shape)
         to_opt_shard_step = 0
         to_opt_shard_size = 0
-        if dst_strategy_file is not None:
+        if dst_strategy_list is not None:
             if param_name not in dst_strategy_list:
                 ms.log.warning("The parameter {} is not in dst_strategy.".format(param_name))
                 continue
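`_transform_parallel_checkpoint` follows the same pattern: its `src_strategy_list`/`dst_strategy_list` arguments are pre-converted layouts, so the dead `device_num` computation and the per-call file parsing are gone. A hedged sketch via the public wrapper `mindspore.transform_checkpoints` (directories are hypothetical; the source directory is expected to contain one `rank_x` subdirectory per device):

    import mindspore as ms

    ms.transform_checkpoints("./src_checkpoints", "./dst_checkpoints",
                             "transformed",
                             src_strategy_file="src_strategy.ckpt",
                             dst_strategy_file="dst_strategy.ckpt")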
@@ -40,7 +40,7 @@ def _need_reset_device_target_for_ps(target):
     For Ascend backend, the card can't be occupied by multiple processes in distributed training,
     so we need to reset the device target for some roles.
     '''
-    is_server = (
+    is_server = (os.getenv('MS_ROLE') in ["MS_PSERVER", "MS_SERVER", "MS_SCHED"])
     return is_server and target == "Ascend"
 
 
@@ -184,10 +184,6 @@ def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
     ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
 
 
-def _insert_weight_init_info(name, global_seed, op_seed):
-    ps_context().insert_weight_init_info(name, global_seed, op_seed)
-
-
 def _insert_accumu_init_info(name, init_val):
     ps_context().insert_accumu_init_info(name, init_val)
 
@@ -210,6 +206,14 @@ def _cache_enable():
     return ps_context().cache_enable()
 
 
+def _set_cache_size(cache_size):
+    ps_context().set_cache_size(cache_size)
+
+
+def _set_sparse_format(sparse_format):
+    ps_context().set_sparse_format(sparse_format)
+
+
 def _set_rank_id(rank_id):
     ps_context().set_rank_id(rank_id)
 
@@ -14,7 +14,7 @@
 # ============================================================================
 """Context for recovery"""
 
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore._c_expression import RecoveryContext
 
 RECOVERY_CONTEXT = None
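The role check is now spelled out inline: any of the three parameter-server roles must not hold an Ascend card. A standalone sketch of the same logic:

    import os

    def need_reset_device_target_for_ps(target):
        # Scheduler and (parameter) server processes must not occupy an
        # Ascend device in distributed training.
        is_server = os.getenv('MS_ROLE') in ["MS_PSERVER", "MS_SERVER", "MS_SCHED"]
        return is_server and target == "Ascend"

    os.environ['MS_ROLE'] = "MS_SCHED"
    assert need_reset_device_target_for_ps("Ascend")
    assert not need_reset_device_target_for_ps("GPU")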
mindspore/parallel/_tensor.py
CHANGED
@@ -15,7 +15,9 @@
 """load tensor and combine tensor"""
 from __future__ import division
 from __future__ import absolute_import
+
 import numpy as np
+
 from mindspore.common.tensor import Tensor
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore._c_expression import TensorTransform
@@ -173,20 +175,26 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
     return _chunk_tensor(np_tensor, strategy, len(strategy))
 
 
-def _get_slice_index(dev_mat, tensor_map):
+def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     """
     Get the slice index for current slice.
 
     Args:
         dev_mat (list): The device matrix of devices.
         tensor_map (list): The split strategy of tensor.
+        opt_shard_group(string): The group of optimizer shard
 
     Returns:
         Integer, the slice index for slice on this device.
     """
     rank = get_rank()
+    dev_num = get_group_size()
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    if opt_shard_group:
+        tensor_slice_index += dev_num
+        opt_rank = get_rank(opt_shard_group)
+        tensor_slice_index += opt_rank
     return tensor_slice_index
 
 
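With an optimizer-shard group set, the slice index is pushed past the plain per-device range and then offset by the rank inside the shard group. A standalone sketch of the arithmetic (values are made up):

    def slice_index_with_opt_shard(base_index, dev_num, opt_rank, opt_shard_group):
        # Mirrors the added branch: no group means the plain index is used.
        index = base_index
        if opt_shard_group:
            index += dev_num       # move past the non-sharded index range
            index += opt_rank      # offset by rank within the shard group
        return index

    assert slice_index_with_opt_shard(3, 8, 1, "group0") == 12
    assert slice_index_with_opt_shard(3, 8, 1, None) == 3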
@@ -15,17 +15,17 @@
 """
 NOTE:
     Transformer Networks.
-
+    These are experimental APIs that are subject to change or deletion.
 """
 from __future__ import absolute_import
 
-from mindspore.
+from mindspore.parallel._transformer.transformer import AttentionMask, VocabEmbedding, MultiHeadAttention, \
     FeedForward, TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer, \
     Transformer, TransformerOpParallelConfig, EmbeddingOpParallelConfig, TransformerRecomputeConfig
-from mindspore.
-from mindspore.
-from mindspore.
-from mindspore.
+from mindspore.parallel._transformer.moe import MoEConfig
+from mindspore.parallel._transformer.layers import FixedSparseAttention
+from mindspore.parallel._transformer.loss import CrossEntropyLoss
+from mindspore.parallel._transformer.op_parallel_config import OpParallelConfig
 
 __all__ = []
 __all__.extend(transformer.__all__)
@@ -33,11 +33,12 @@ from mindspore._extends import cell_attr_register
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 from mindspore.ops import functional as F
-from mindspore
+from mindspore import _checkparam as Validator
 from mindspore.ops.primitive import constexpr
 from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
 from mindspore.context import ParallelMode
-from mindspore.
+from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
+from mindspore import log as logger
 
 __all__ = [
     "FixedSparseAttention"
@@ -154,6 +155,29 @@ class _LayerInputCheck:
                              f"but got {input_shape[dim]}")
         return True
 
+    @staticmethod
+    def check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
+        """
+        Check whether the input shape is equal to the expected shape; the value on the 0-th axis is
+        viewed as the batch, and the batch size will not be checked.
+        """
+        length, hidden = target_shape
+        if isinstance(input_shape, tuple):
+            input_shape = list(input_shape)
+        _LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
+                                            [len(target_shape), len(target_shape) + 1])
+        if input_shape[-1] != hidden:
+            raise ValueError(f"For {func_name}, the last dimension of {param_name} shape must be {hidden}, "
+                             f"but got the last dimension {input_shape[-1]} in {input_shape}.")
+        if input_shape[0] == 0:
+            raise ValueError(f"For {func_name}, the first dimension of {param_name} shape must be greater than 0, "
+                             f"but got the first dimension {input_shape[0]} in {input_shape}.")
+        if len(input_shape) == 2 and input_shape[0] % length != 0:
+            raise ValueError(f"For {func_name}, the first dimension of {param_name} shape should be divisible "
+                             f"by {length}, "
+                             f"but got the first dimension {input_shape[0]} in {input_shape}.")
+        return True
+
 
 @constexpr
 def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
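The new check accepts either the batched shape [batch, length, hidden] or its flattened form [batch * length, hidden]. A standalone sketch of the acceptance rule:

    def shape_ok_without_batch(input_shape, target_shape):
        length, hidden = target_shape
        if len(input_shape) not in (len(target_shape), len(target_shape) + 1):
            return False                     # wrong rank
        if input_shape[-1] != hidden or input_shape[0] == 0:
            return False                     # wrong hidden dim or empty batch
        if len(input_shape) == 2 and input_shape[0] % length != 0:
            return False                     # flattened dim must divide by length
        return True

    assert shape_ok_without_batch((4, 16, 128), (16, 128))      # batched
    assert shape_ok_without_batch((64, 128), (16, 128))         # flattened
    assert not shape_ok_without_batch((65, 128), (16, 128))     # 65 % 16 != 0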
@@ -175,23 +199,6 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
 
-@constexpr
-def _check_input_shape(input_shape, param_name, func_name, target_len):
-    # check the input length
-    _LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)
-
-
-@constexpr
-def _check_shape_equal(input_shape, param_name, func_name, target_shape):
-    # check the input length
-    _LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)
-
-
-@constexpr
-def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
-    _LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
-
-
 class _Dropout(nn.Cell):
     r"""
     A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
@@ -361,7 +368,7 @@ class _Linear(Cell):
             same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (str): activate function applied to the output of the fully connected layer,
-            eg. 'ReLU'.Default: None.
+            eg. 'ReLU'. Default: None.
         expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is
             used and the first dimension in BatchMatMul indicate expert_num. Default: 1.
         outer_batch (int): The replication number of experts. The replication is effective only when MoE is applied.
@@ -392,7 +399,6 @@ class _Linear(Cell):
     @_args_type_validator_check(in_channels=Validator.check_positive_int,
                                 out_channels=Validator.check_positive_int,
                                 has_bias=Validator.check_bool,
-                                activation=_valid_type_checks([type(None), str], "Linear"),
                                 transpose_b=Validator.check_bool,
                                 expert_num=Validator.check_positive_int,
                                 outer_batch=Validator.check_positive_int,
@@ -416,7 +422,9 @@ class _Linear(Cell):
         super(_Linear, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        if isinstance(
+        if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
+            raise TypeError(f"For Linear cell, the activation should be str type or nn.Cell type, but got {activation}.")
+        if isinstance(weight_init, Tensor) and (weight_init.ndim != 2 or weight_init.shape[0] != out_channels or
                                                 weight_init.shape[1] != in_channels):
             raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
         weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
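With the validator entry removed and the type check above, the activation may now be a string name, None, or an `nn.Cell` subclass (passed as the class itself, which the cell later instantiates because it is callable). A hedged sketch of a user-defined activation that would pass the check (`MyGELU` is hypothetical):

    import mindspore.nn as nn

    class MyGELU(nn.Cell):
        """Hypothetical user-defined activation cell."""
        def __init__(self):
            super().__init__()
            self.gelu = nn.GELU()

        def construct(self, x):
            return self.gelu(x)

    def activation_type_ok(activation):
        # Same intent as the new check, with a guard so instances do not raise.
        return isinstance(activation, str) or activation is None \
            or (isinstance(activation, type) and issubclass(activation, nn.Cell))

    assert activation_type_ok("relu")
    assert activation_type_ok(MyGELU)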
@@ -449,7 +457,10 @@ class _Linear(Cell):
             self.bias.parallel_optimizer = False
             self.bias_add = P.Add()
         self.act_name = activation
-
+        if callable(activation):
+            self.activation = activation()
+        else:
+            self.activation = get_activation(activation) if isinstance(activation, str) else activation
         self.activation_flag = self.activation is not None
         self.dtype = compute_dtype
         self.cast = P.Cast()
@@ -491,7 +502,7 @@ class _Linear(Cell):
             self.matmul.shard(strategy_matmul)
             if self.has_bias:
                 self.bias_add.shard(strategy_bias)
-            if self.activation_flag:
+            if self.activation_flag and isinstance(self.act_name, str):
                 # some operations has many primitives, need to manually set the shard
                 if self.act_name.lower() == "leakyrelu":
                     self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
@@ -506,7 +517,26 @@ class _Linear(Cell):
                                      "or auto parallel mode.")
                 else:
                     getattr(self.activation, self.act_name).shard(strategy_activation)
-
+            elif self.activation_flag and isinstance(self.activation, Cell):
+                if hasattr(self.activation, 'activation_shard') and strategy_activation:
+                    shard_tuple = strategy_activation[0]
+                    if len(shard_tuple) == 2:
+                        parallel_config = OpParallelConfig(data_parallel=shard_tuple[0],
+                                                           model_parallel=shard_tuple[1])
+                    elif len(shard_tuple) == 4:
+                        parallel_config = MoEParallelConfig(data_parallel=shard_tuple[0],
+                                                            expert_parallel=shard_tuple[1],
+                                                            model_parallel=shard_tuple[2])
+                    else:
+                        raise ValueError("The user-defined activation function currently only supports the case where the "
+                                         "input policy is 2 or 4, so that relevant policies can be extracted from it. "
+                                         "To avoid this error, you need to add the function of extracting "
+                                         "'ParallelConfig' or 'OpParallelConfig' for the incoming strategy_activation ")
+                    self.activation.activation_shard(parallel_config)
+                else:
+                    logger.warning(f"The user passed the custom defined activation function {self.activation_flag}. "
+                                   f"If the user want to enable shard for the activation cell, "
+                                   f"the user should set the shard for each primitives in the cell.")
             return self
 
 
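For a custom activation cell that exposes `activation_shard`, the linear layer derives a parallel config from the first activation strategy tuple: two elements map to (data_parallel, model_parallel), four to an MoE layout. A standalone sketch of that extraction rule:

    def extract_parallel_config(shard_tuple):
        # 2-tuple -> plain op config; 4-tuple -> MoE config (the last element
        # of the 4-tuple is unused by the mapping above).
        if len(shard_tuple) == 2:
            return {"data_parallel": shard_tuple[0], "model_parallel": shard_tuple[1]}
        if len(shard_tuple) == 4:
            return {"data_parallel": shard_tuple[0],
                    "expert_parallel": shard_tuple[1],
                    "model_parallel": shard_tuple[2]}
        raise ValueError("activation strategy must have 2 or 4 elements")

    assert extract_parallel_config((8, 1)) == {"data_parallel": 8, "model_parallel": 1}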
@@ -540,13 +570,13 @@ class FixedSparseAttention(nn.Cell):
             default args.
 
     Inputs:
-        - **q** (Tensor) - Tensor query (
+        - **q** (Tensor) - Tensor query ( `mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
             queries to query the context.
-        - **k** (Tensor) - Tensor key (
+        - **k** (Tensor) - Tensor key ( `mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
             queries to query the context.
-        - **v** (Tensor) - Tensor value (
+        - **v** (Tensor) - Tensor value ( `mstype.fp16` [batch size, sequence length, Embedding Size]):
             Sequence of queries to query the context.
-        - **attention_mask** (Tensor) - Float Tensor the mask of (
+        - **attention_mask** (Tensor) - Float Tensor the mask of ( `mstype.fp32`, `mstype.fp16`
             [batch_size, seq_length, seq_length]): Lower triangular matrix to pass masked information.
 
     Outputs:
@@ -654,17 +684,9 @@ class FixedSparseAttention(nn.Cell):
         self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))
 
     def construct(self, q, k, v, attention_mask):
-        _check_shape_equal(F.shape(q), "q", self.cls_name,
-                           [self.batch_size, self.seq_length, self.hidden_size])
         _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
-        _check_shape_equal(F.shape(k), "k", self.cls_name,
-                           [self.batch_size, self.seq_length, self.hidden_size])
         _check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name)
-        _check_shape_equal(F.shape(v), "v", self.cls_name,
-                           [self.batch_size, self.seq_length, self.hidden_size])
         _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
-        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
-                           [self.batch_size, self.seq_length, self.seq_length])
         _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
 
         q, k, v = self._transpose_inputs(q, k, v)
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """
-Parallel Loss for the Parallel Training
-
+Parallel Loss for the Parallel Training.
+These are experimental APIs that are subject to change or deletion.
 """
 from __future__ import absolute_import
 
@@ -30,8 +30,8 @@ from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
 from mindspore.log import _LogActionOnce
 from mindspore import log as logger
-from mindspore.
-from mindspore.
+from mindspore.parallel._transformer.layers import _check_input_dtype
+from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
 
 __all__ = ["CrossEntropyLoss"]
 
@@ -247,7 +247,4 @@ class CrossEntropyLoss(Cell):
         _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
         _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
-        _check_input_shape(F.shape(logits), "logits", self.cls_name, 2)
-        _check_input_shape(F.shape(label), "label", self.cls_name, 1)
-        _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1)
         return True
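With the rank checks gone, `construct` validates dtypes only, but the shapes the removed checks enforced are still the expected layout: logits of rank 2 and label/input_mask of rank 1. A hedged usage sketch (shapes and the import path follow the diff; running it requires a MindSpore build that ships this module):

    import numpy as np
    import mindspore as ms
    from mindspore.parallel._transformer.loss import CrossEntropyLoss

    loss = CrossEntropyLoss()
    n, vocab = 8, 32000                      # n = batch_size * seq_length
    logits = ms.Tensor(np.random.randn(n, vocab).astype(np.float32))
    label = ms.Tensor(np.zeros(n).astype(np.int32))
    input_mask = ms.Tensor(np.ones(n).astype(np.float32))
    output = loss(logits, label, input_mask)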