mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/parallel/_tensor.py
CHANGED

```diff
@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import absolute_import
 
 import numpy as np
-
 from mindspore.common.tensor import Tensor
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore._c_expression import TensorTransform
@@ -41,7 +40,7 @@ def _get_tensor_strategy(dev_mat, tensor_map):
         if dim == -1:
             tensor_strategy.append(1)
         else:
-            tensor_strategy.append(dev_mat[-dim-1])
+            tensor_strategy.append(dev_mat[-dim - 1])
     return tensor_strategy
 
 
@@ -198,7 +197,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     return tensor_slice_index
 
 
-def _load_tensor(tensor, dev_mat, tensor_map):
+def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
    """
    Get the tensor slice of the local device by the device matrix and the tensor map
 
@@ -216,7 +215,10 @@ def _load_tensor(tensor, dev_mat, tensor_map):
    >>> tensor_map = [1, -1]
    >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    """
-    rank = get_rank()
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
     np_tensor = tensor.asnumpy()
@@ -225,7 +227,7 @@ def _load_tensor(tensor, dev_mat, tensor_map):
     return np_tensor_slice
 
 
-def _load_tensor_by_layout(tensor, layout):
+def _load_tensor_by_layout(tensor, layout, rank_id):
    """
    Load tensor by layout.
 
@@ -246,13 +248,13 @@ def _load_tensor_by_layout(tensor, layout):
         raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
+    if not tensor_map:
+        return tensor
     uniform_split = layout[4]
     group = layout[5]
     if uniform_split == 0:
         raise RuntimeError("The load tensor only support uniform split now")
-    if not tensor_map:
-        return tensor
-    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
+    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, rank_id)
     if group:
         # get a totally shard tensor slice for parallel optimizer
         rank = get_rank(group)
@@ -315,7 +317,6 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
     return Tensor(tensor_slices_new[0])
 
 
-
 def _extract_layout_item(layout_item):
     dev_matrix = layout_item[0]
     tensor_map = layout_item[1]
@@ -541,6 +542,7 @@ def _check_operator(operator):
 
 def _apply_operator(operator_name):
     """apply transform operator"""
+
     def _apply_reshape_operator(numpy_data, reshape_op):
         """
         Apply reshape operator.
@@ -597,8 +599,8 @@ def _apply_operator(operator_name):
             raise ValueError("The slice operator information is wrong.")
         shape_size = len(slice_op[1]) // 3
         begin = slice_op[1][:shape_size]
-        end = slice_op[1][shape_size:shape_size*2]
-        stride = slice_op[1][shape_size*2:]
+        end = slice_op[1][shape_size:shape_size * 2]
+        stride = slice_op[1][shape_size * 2:]
         slice_index = []
         for begin_i, end_i, strides_i in zip(begin, end, stride):
             s = slice(begin_i, end_i, strides_i)
@@ -637,8 +639,8 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(len(tensor_slices[0][0])):
         tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
         for j in range(1, device_count):
-            tensor_slices_new = np.concatenate((tensor_slices_new,\
-                                                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
+            tensor_slices_new = np.concatenate((tensor_slices_new, \
+                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
         tensor_slices_col.append(tensor_slices_new)
     new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
     for i in range(1, len(tensor_slices_col)):
```
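The functional change in this file is the new `rank_id` argument: `_load_tensor` used to slice for the calling process only (`get_rank()`), whereas it can now compute the slice owned by any given rank, and `_load_tensor_by_layout` passes the rank through. Below is a minimal standalone sketch of what such per-rank slicing looks like, assuming a row-major decomposition of the rank against the device matrix; `load_tensor_for_rank` and the NumPy stand-in are illustrative, not MindSpore API:

```python
import numpy as np

def _get_tensor_strategy(dev_mat, tensor_map):
    # Same logic as the hunk above: number of shards per tensor dimension.
    tensor_strategy = []
    for dim in tensor_map:
        if dim == -1:
            tensor_strategy.append(1)
        else:
            tensor_strategy.append(dev_mat[-dim - 1])
    return tensor_strategy

def load_tensor_for_rank(tensor, dev_mat, tensor_map, rank_id):
    # Illustrative re-implementation of the per-rank slicing that
    # _load_tensor(..., rank_id=...) enables: decompose rank_id into
    # device-matrix coordinates, then slice each sharded dimension.
    strategy = _get_tensor_strategy(dev_mat, tensor_map)
    coords = []  # coords[k] = coordinate along dev_mat axis k, counted from the right
    remaining = rank_id
    for d in reversed(dev_mat):
        coords.append(remaining % d)
        remaining //= d
    slices = []
    for axis, dim in enumerate(tensor_map):
        if dim == -1:
            slices.append(slice(None))  # dimension not sharded
        else:
            size = tensor.shape[axis] // strategy[axis]
            idx = coords[dim]
            slices.append(slice(idx * size, (idx + 1) * size))
    return tensor[tuple(slices)]

full = np.arange(16).reshape(4, 4)
# dev_mat [2, 2]; first tensor dim mapped to dev_mat axis 1 (from the right):
print(load_tensor_for_rank(full, [2, 2], [1, -1], rank_id=3))  # rows 2..3
```

Making the rank explicit means a caller can materialize any device's slice of a parameter, not just its own, which is useful, for example, when deriving another device's slice without running on that device.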
mindspore/parallel/_transformer/layers.py
CHANGED

```diff
@@ -366,14 +366,14 @@ class _Linear(Cell):
             is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
             same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``True``.
         activation (str): activate function applied to the output of the fully connected layer,
-            eg. 'ReLU'. Default: None
+            eg. 'ReLU'. Default: ``None``.
         expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is
             used and the first dimension in BatchMatMul indicate expert_num. Default: 1.
         outer_batch (int): The replication number of experts. The replication is effective only when MoE is applied.
             Default: 1.
-        expert_group_size (int): The number of tokens in each data parallel group. Default: None
+        expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
         compute_dtype (dtype.Number): The computation type. Default: mstype.float16
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
@@ -424,9 +424,11 @@ class _Linear(Cell):
         self.out_channels = out_channels
         if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
             raise TypeError(f"For Linear cell, the activation should str type or nn.Cell type, but got {activation}.")
-
-
-
+
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels \
+                    or weight_init.shape[1] != in_channels:
+                raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
         weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
         self.expert_num = expert_num
         self.outer_batch = outer_batch
```
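The second hunk adds a shape guard for `Tensor`-valued `weight_init`. Here is a self-contained restatement of the new check, with a NumPy array standing in for `mindspore.Tensor` (the helper name is hypothetical):

```python
import numpy as np

def check_weight_init(weight_init, out_channels, in_channels):
    # Mirrors the guard added to _Linear.__init__: a tensor-valued
    # weight_init must be 2-D with shape (out_channels, in_channels).
    if isinstance(weight_init, np.ndarray):
        if weight_init.ndim != 2 or weight_init.shape[0] != out_channels \
                or weight_init.shape[1] != in_channels:
            raise ValueError("The shape of parameter 'weight_init' is error, "
                             "please check shape of 'weight_init'.")

check_weight_init(np.zeros((8, 4)), out_channels=8, in_channels=4)  # passes
try:
    check_weight_init(np.zeros((4, 8)), out_channels=8, in_channels=4)
except ValueError as err:
    print(err)  # a mismatched shape is now rejected at construction time
```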
mindspore/parallel/_transformer/loss.py
CHANGED

```diff
@@ -139,6 +139,7 @@ class _NLLLoss(Cell):
         self.add = P.Add().shard(((dp, mp), ()))
 
     def construct(self, softmax_result, one_hot_label):
+        """The forward of _NLLLoss"""
         log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
         loss = self.mul(log_softmax_result, one_hot_label)
         loss_unsum = self.neg(loss)
```
@@ -51,13 +51,13 @@ class MoEConfig:
|
|
|
51
51
|
router) to be added to the entire model loss, which is < 1.0. Default: 0.05.
|
|
52
52
|
num_experts_chosen (int): The number of experts is chosen by each token and it should not be larger
|
|
53
53
|
than expert_num. Default: 1.
|
|
54
|
-
expert_group_size (int): The number of tokens in each data parallel group. Default: None
|
|
55
|
-
effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
|
|
54
|
+
expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
|
|
55
|
+
This parameter is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
|
|
56
56
|
group_wise_a2a (bool): Whether to enable group-wise alltoall communication, which can reduce communication
|
|
57
|
-
time by converting part of inter communication into intra communication. Default: False
|
|
58
|
-
is effective only when model parallel > 1 and data_parallel equal to expert parallel.
|
|
57
|
+
time by converting part of inter communication into intra communication. Default: ``False``.
|
|
58
|
+
This parameter is effective only when model parallel > 1 and data_parallel equal to expert parallel.
|
|
59
59
|
comp_comm_parallel (bool): Whether to enable ffn compute and communication parallel, which can reduce pure
|
|
60
|
-
communicattion time by splitting and overlapping compute and communication. Default: False
|
|
60
|
+
communicattion time by splitting and overlapping compute and communication. Default: ``False``.
|
|
61
61
|
comp_comm_parallel_degree (int): The split number of compute and communication. The larger the numbers,
|
|
62
62
|
the more overlap there will be but will consume more memory. Default: 2. This parameter is effective
|
|
63
63
|
only when comp_comm_parallel enable.
|
|
@@ -273,7 +273,7 @@ class MoE(Cell):
         if self.group_wise_a2a:
             # If capacity can't div by mp, pad for mp shard.
             if capacity % self.mp != 0:
-                pad_size = self.mp-(capacity % self.mp)
+                pad_size = self.mp - (capacity % self.mp)
             if pad_size != 0:
                 capacity += pad_size
                 pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
@@ -330,7 +330,7 @@ class MoE(Cell):
         # Pad capacity for comp_comm_parallel_degree split.
         pad_size = 0
         if capacity % self.comp_comm_parallel_degree != 0:
-            pad_size = self.comp_comm_parallel_degree-(capacity % self.comp_comm_parallel_degree)
+            pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
             capacity += pad_size
             pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
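Both padding hunks implement the same pad-to-multiple arithmetic: when `capacity % n != 0`, pad by `n - capacity % n`. A standalone sketch with a hypothetical helper name:

```python
def pad_to_multiple(capacity, n):
    """Return (padded_capacity, pad_size) so that padded_capacity % n == 0."""
    pad_size = 0
    if capacity % n != 0:
        pad_size = n - (capacity % n)
    return capacity + pad_size, pad_size

assert pad_to_multiple(10, 4) == (12, 2)  # e.g. capacity 10 with mp = 4
assert pad_to_multiple(8, 4) == (8, 0)    # already divisible: no padding
```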
@@ -147,10 +147,11 @@ class _PipeLineConfig(_Config):
         >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
     """
 
-    def __init__(self, pipeline_stage=1, micro_batch_num=1):
+    def __init__(self, pipeline_stage=1, micro_batch_num=1, pipeline_segment=1):
         Validator.check_positive_int(pipeline_stage, "pipeline_stage")
         Validator.check_positive_int(micro_batch_num, "micro_batch_num")
         self.pipeline_stage = pipeline_stage
+        self.pipeline_segment = pipeline_segment
         self.micro_batch_num = micro_batch_num
 
     @property
@@ -163,6 +164,16 @@ class _PipeLineConfig(_Config):
         self._pipeline_stage = value
         context.set_auto_parallel_context(pipeline_stages=value)
 
+    @property
+    def pipeline_segment(self):
+        return self._pipeline_segment
+
+    @pipeline_segment.setter
+    def pipeline_segment(self, value):
+        Validator.check_positive_int(value, "pipeline_segment")
+        self._pipeline_segment = value
+        context.set_auto_parallel_context(pipeline_segments=value)
+
     @property
     def micro_batch_num(self):
         return self._micro_batch_num
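The added property follows a validate-then-store-then-propagate pattern (the real setter also forwards the value to `context.set_auto_parallel_context`). A trimmed, dependency-free sketch of the same pattern:

```python
class _PipelineConfigSketch:
    """Stand-in class; only mirrors the validated-property shape of the diff."""

    def __init__(self, pipeline_segment=1):
        self.pipeline_segment = pipeline_segment

    @property
    def pipeline_segment(self):
        return self._pipeline_segment

    @pipeline_segment.setter
    def pipeline_segment(self, value):
        # plain-Python stand-in for Validator.check_positive_int
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"'pipeline_segment' must be a positive int, but got {value!r}")
        self._pipeline_segment = value

cfg = _PipelineConfigSketch()
cfg.pipeline_segment = 2   # ok
# cfg.pipeline_segment = 0 would raise ValueError
```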
@@ -70,7 +70,7 @@ class EmbeddingOpParallelConfig(_Config):
         vocab_emb_dp(bool): Shard embedding in model parallel or data parallel. If True, the embedding lookup
             will be a data parallel style training and model_parallel value will be ignored. If false, the
             embedding table will be sharded into n parts at the 0-th dimension row slice of the embedding table,
-            where the n is the model parallel way determined by this parameter. Default: True
+            where the n is the model parallel way determined by this parameter. Default: ``True``
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -120,13 +120,13 @@ class TransformerRecomputeConfig(_Config):
     TransformerRecomputeConfig for the setting recompute attributes for encoder/decoder layers.
 
     Args:
-        recompute (bool): Enable recomputation of the transformer block or not. Default: False
+        recompute (bool): Enable recomputation of the transformer block or not. Default: ``False``.
         parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
             introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
-            Default: False
+            Default: ``False``.
         mp_comm_recompute (bool): Specifies whether the model parallel communication operators
-            in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True
-        recompute_slice_activation (bool): Slice the cell output which would remains in memory. Default: False
+            in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True``.
+        recompute_slice_activation (bool): Slice the cell output which would remains in memory. Default: ``False``.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -214,7 +214,7 @@ class TransformerOpParallelConfig(_Config):
         gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
         recompute (Union[TransformerRecomputeConfig, bool]): The configuration of recomputation for
             the transformer block. Default: An instance of TransformerRecomputeConfig with default values.
-        vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True
+        vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: ``True``.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -226,7 +226,8 @@ class TransformerOpParallelConfig(_Config):
         >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
     """
 
-    def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1,
+    def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, pipeline_segment=1,
+                 micro_batch_num=1,
                  recompute=default_transformer_recompute_config,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
         self.recompute = recompute
@@ -234,7 +235,8 @@ class TransformerOpParallelConfig(_Config):
         self.gradient_aggregation_group = gradient_aggregation_group
         self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                                              vocab_emb_dp=vocab_emb_dp)
-        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
+        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num,
+                                          pipeline_segment=pipeline_segment)
         self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                              expert_parallel=expert_parallel)
@@ -309,6 +311,14 @@ class TransformerOpParallelConfig(_Config):
     def pipeline_stage(self, value):
         self._pp_config.pipeline_stage = value
 
+    @property
+    def pipeline_segment(self):
+        return self._pp_config.pipeline_segment
+
+    @pipeline_segment.setter
+    def pipeline_segment(self, value):
+        self._pp_config.pipeline_segment = value
+
     @property
     def optimizer_shard(self):
         return self._optimizer_shard
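End to end, the new keyword threads from `TransformerOpParallelConfig` down into `_PipeLineConfig`, and the property pair above simply delegates. A usage sketch, assuming the `mindspore.nn.transformer` import path and noting that the underlying setter also touches the global auto-parallel context:

```python
from mindspore.nn.transformer import TransformerOpParallelConfig  # path assumed

config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1,
                                     pipeline_stage=2, pipeline_segment=2)
print(config.pipeline_segment)  # 2, read through the wrapped _PipeLineConfig
config.pipeline_segment = 4     # delegated to _pp_config.pipeline_segment
```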
@@ -359,8 +369,8 @@ class FeedForward(Cell):
             the `activation_shard` function. Please see examples. Default: gelu.
         expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
             and the first dimension in BatchMatMul indicate expert_num. Default: 1.
-        expert_group_size (int): The number of tokens in each data parallel group. Default: None
-            effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
+        expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
+            This parameter is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
         param_init_type (dtype.Number): The parameter initialization type. Should be mstype.float32 or
             mstype.float16. Default: mstype.float32.
         parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see
@@ -429,6 +439,7 @@ class FeedForward(Cell):
         >>> print(output.shape)
         (2, 20, 15)
     """
+
     @_LogActionOnce(logger=logger, key='FeedForward',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -622,6 +633,7 @@ class AttentionMask(Cell):
         [1. 1. 1. 0]
         [0. 0. 0. 0]]]
     """
+
     @_LogActionOnce(logger=logger, key='AttentionMask',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(seq_length=Validator.check_positive_int,
@@ -710,6 +722,7 @@ class VocabEmbedding(Cell):
         >>> print(table.shape)
         (30, 30)
     """
+
     @_LogActionOnce(logger=logger, key='VocabEmbedding',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(vocab_size=Validator.check_positive_int,
@@ -866,6 +879,7 @@ class MultiHeadAttention(Cell):
         >>> print(past[1].shape)
         (2, 3, 20, 5)
     """
+
     @_LogActionOnce(logger=logger, key='MultiHeadAttention',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -1203,7 +1217,8 @@ class MultiHeadAttention(Cell):
     def _get_batch_size_from_query(self, query):
         r"""Get the batch size from query tensor"""
         # For the incremental prediction, the seq length for the input is 1.
-        if len(F.shape(query)) == 2 and ((self.use_past and self.is_first_iteration) or (not self.use_past)):
+        incr_infer = self.use_past and self.is_first_iteration
+        if len(F.shape(query)) == 2 and ((incr_infer) or (not self.use_past)):
             return F.shape(query)[0] // self.src_seq_length
         return F.shape(query)[0]
 
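The refactored condition is pure arithmetic: a 2-D query of shape `(batch * src_seq_length, hidden)` encodes the batch size implicitly, so it is recovered by integer division. A numpy illustration:

```python
import numpy as np

src_seq_length = 16
query = np.zeros((2 * src_seq_length, 512), np.float32)  # batch of 2, flattened
if query.ndim == 2:
    batch_size = query.shape[0] // src_seq_length
print(batch_size)  # 2
```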
@@ -1459,6 +1474,7 @@ class TransformerEncoderLayer(Cell):
         >>> print(past[1].shape)
         (2, 2, 16, 4)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -1848,6 +1864,7 @@ class TransformerDecoderLayer(Cell):
         >>> print(past[3].shape)
         (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,
@@ -2292,7 +2309,7 @@ class TransformerEncoder(Cell):
             represents the transformer block, `layer_id(int)` means the layer index for the current module, counts
             from zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net.
             The default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
-            Default: None
+            Default: ``None``.
         offset(int): The initial layer index for the `encoder`. Used for setting the fusion id and stage id, to not
             overlap with the encoder layer. Default 0.
         use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
@@ -2301,7 +2318,7 @@ class TransformerEncoder(Cell):
             In the first step, set the is_first_iteration to be True by
             `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the
             is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment,
-            pass the single step's input tensor, and loop it. Default: False
+            pass the single step's input tensor, and loop it. Default: ``False``.
         moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
             with default values. Please see `MoEConfig`.
         parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
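The two-phase recipe in this docstring can be summarized as a small driver loop; the model below is a placeholder object, not a real TransformerEncoder:

```python
class FakeModel:
    """Placeholder standing in for a use_past-enabled network."""
    def add_flags_recursive(self, **flags):
        self.flags = flags
    def __call__(self, inputs):
        return inputs  # placeholder forward

model = FakeModel()
model.add_flags_recursive(is_first_iteration=True)
_ = model("full-sequence inputs")              # step 1: pass the full inputs
model.add_flags_recursive(is_first_iteration=False)
for step_inputs in ("token 1", "token 2"):     # then loop single-step tensors
    _ = model(step_inputs)
```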
@@ -2379,6 +2396,7 @@ class TransformerEncoder(Cell):
         >>> print(past[0][1].shape)
         (2, 2, 16, 4)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerEncoder',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -2548,7 +2566,7 @@ class TransformerDecoder(Cell):
             represents the transformer block, `layer_id(int)` means the layer index for the current module, counts
             from zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net.
             The default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
-            Default: None
+            Default: ``None``.
         use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
         offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not
             overlap with the encoder layer. Default 0.
@@ -2613,6 +2631,7 @@ class TransformerDecoder(Cell):
         >>> print(past[0][3].shape)
         (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerDecoder',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
@@ -2882,6 +2901,7 @@ class Transformer(Cell):
         >>> print(de_past[0][3].shape)
         (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='Transformer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
mindspore/parallel/_utils.py CHANGED
@@ -96,11 +96,21 @@ def _slice_parameter(parameter, phase, layout):
     new_param = parameter.init_data(layout, set_sliced=True)
     parameter = new_param
     graph_executor.updata_param_node_default_input(phase, {parameter.name: parameter})
-    if not parameter.sliced:
-        new_tensor = _load_tensor_by_layout(parameter, layout)
+    if layout is None:
+        parameter.sliced = True
+        return
+    if not parameter.sliced:
+        rank = get_rank()
+        new_tensor = _load_tensor_by_layout(parameter, layout, rank)
         parameter.set_data(new_tensor, True)
 
 
+def _slice_tensor(tensor, layout, rank_id):
+    """Slice python tensor obj according to the layout."""
+    new_tensor = _load_tensor_by_layout(tensor, layout, rank_id)
+    return new_tensor
+
+
 def _init_optimizer_state(parameter, phase):
     """init optimizer state"""
     if not parameter.has_init:
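The reworked guard short-circuits when no layout is given: the parameter is just marked as sliced instead of attempting a load. A dependency-free sketch of that control flow, with stand-in names:

```python
from dataclasses import dataclass

@dataclass
class FakeParam:
    data: object = None
    sliced: bool = False
    def set_data(self, data, slice_flag):
        self.data, self.sliced = data, slice_flag

def slice_if_needed(parameter, layout, load_fn, rank):
    """Mirrors the new branch order in _slice_parameter."""
    if layout is None:
        parameter.sliced = True   # nothing to slice on this rank
        return
    if not parameter.sliced:
        parameter.set_data(load_fn(parameter, layout, rank), True)

p = FakeParam()
slice_if_needed(p, None, None, rank=0)
assert p.sliced and p.data is None  # layout-less path only flips the flag
```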
@@ -127,14 +137,17 @@ def _to_full_shapes(shapes, device_num):
                                  "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
             new_shape = ()
             for i, item in enumerate(shape):
-                new_shape += (item * dataset_strategy[index][i],)
+                if item > 0:
+                    new_shape += (item * dataset_strategy[index][i],)  # static shape
+                else:
+                    new_shape += (item,)  # dynamic shape
             new_shapes.append(new_shape)
         return new_shapes
     for shape in shapes:
         new_shape = ()
         for i, item in enumerate(shape):
-            if i == 0:
-                new_shape += (item * device_num,)
+            if i == 0 and item > 0:
+                new_shape += (item * device_num,)  # only for static shape
             else:
                 new_shape += (item,)
         new_shapes.append(new_shape)
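The new rule in both branches: scale only positive (static) dimensions up to the full global shape, and pass `-1` dynamic placeholders through untouched. A standalone restatement of the no-dataset-strategy branch, with a hypothetical helper name:

```python
def to_full_shape(shape, device_num):
    """Hypothetical helper restating the default branch of _to_full_shapes."""
    return tuple(item * device_num if i == 0 and item > 0 else item
                 for i, item in enumerate(shape))

print(to_full_shape((32, 224, 224), 8))   # (256, 224, 224): static batch scaled
print(to_full_shape((-1, 224, 224), 8))   # (-1, 224, 224): dynamic batch kept
```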
@@ -192,7 +205,7 @@ def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
             slice_index += (s,)
         new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
         new_tensor_numpy[slice_index] = data.asnumpy()
-        new_tensor = Tensor(new_tensor_numpy)
+        new_tensor = Tensor(new_tensor_numpy, dtype=type_)
         lst.append(new_tensor)
     if scaling_sens:
         lst.append(Tensor(scaling_sens, mstype.float32))
@@ -325,7 +338,7 @@ def _parallel_predict_check():
         dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
         is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"))
         if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp:
-            raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
+            logger.warning('Using non full-batch dataset in model prediction may lead to incorrect data.')
 
 
 def _check_similar_layout(tensor_layout1, tensor_layout2):
@@ -335,7 +348,7 @@ def _check_similar_layout(tensor_layout1, tensor_layout2):
     for i in tensor_layout1[1]:
         if i == -1:
             continue
-        if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]:
+        if tensor_layout1[0][-1 - i] != tensor_layout2[0][-1 - i]:
             return False
     return True
 
@@ -353,7 +366,7 @@ def _remove_repeated_slices(tensor_layout):
     tensor_map = tensor_layout[1]
     for dim in range(len(dev_mat)):
         if dim not in tensor_map:
-            dev_mat[-1-dim] = 1
+            dev_mat[-1 - dim] = 1
     new_tensor_layout[0] = dev_mat
     return new_tensor_layout
 
@@ -409,15 +422,24 @@ def _grads_divided_by_device_num_if_recomputation(grads):
     """
     If in pynative parallel and full_batch is True, divide grads by device num to ensure that the gradients is correct.
     """
-    if not
+    if not _is_pynative_parallel() or not _get_full_batch():
        return grads
 
-    device_num =
+    device_num = _get_device_num()
     logger.info(f"In PyNative mode, when parallel mode is in "
                 f"({context.ParallelMode.SEMI_AUTO_PARALLEL}, {context.ParallelMode.AUTO_PARALLEL}) and "
                 f"full_batch is Ture, the gradients will be automatically divided by device_num({device_num}).")
-    new_grads = ()
-    for grad in grads:
-        new_grads += (grad / device_num,)
 
+    if not isinstance(grads, (tuple, Tensor)):
+        raise ValueError(f"The type of grads must be either Tuple[Tensor] or Tensor, but got {type(grads)}.")
+
+    if isinstance(grads, tuple):
+        new_grads = ()
+        if grads:
+            device_num_tensor = Tensor(device_num, grads[0].dtype)
+            for grad in grads:
+                new_grads += (grad / device_num_tensor,)
+    else:
+        device_num_tensor = Tensor(device_num, grads.dtype)
+        new_grads = grads / device_num_tensor
     return new_grads
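The rewritten tail now accepts either a tuple of gradients or a single tensor, and divides by the device count as a tensor of the matching dtype. A sketch of the same logic as a free function (the name and driver below are illustrative):

```python
import numpy as np
from mindspore import Tensor

def divide_grads(grads, device_num):
    """Mirrors the tuple-or-Tensor handling added above."""
    if not isinstance(grads, (tuple, Tensor)):
        raise ValueError(f"The type of grads must be either Tuple[Tensor] or Tensor, but got {type(grads)}.")
    if isinstance(grads, tuple):
        if not grads:
            return grads
        divisor = Tensor(device_num, grads[0].dtype)
        return tuple(g / divisor for g in grads)
    return grads / Tensor(device_num, grads.dtype)

scaled = divide_grads((Tensor(np.ones((2, 2), np.float32)),), 8)
print(scaled[0])  # every element becomes 0.125
```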