mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -52,12 +52,15 @@ class ReduceOp:
|
|
|
52
52
|
Before running the following examples, you need to configure the communication environment variables.
|
|
53
53
|
|
|
54
54
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
55
|
-
Please see the `
|
|
56
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
55
|
+
Please see the `rank table Startup
|
|
56
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
57
57
|
for more details.
|
|
58
58
|
|
|
59
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
60
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
59
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
60
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
61
|
+
|
|
62
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
63
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
61
64
|
|
|
62
65
|
This example should be run with multiple devices.
|
|
63
66
|
|
|
@@ -117,8 +120,9 @@ class AllReduce(Primitive):
|
|
|
117
120
|
|
|
118
121
|
Args:
|
|
119
122
|
op (str): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
|
|
120
|
-
On the CPU, only 'sum' is supported. Default: ReduceOp.SUM.
|
|
121
|
-
group (str): The communication group to work on. Default:
|
|
123
|
+
On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
|
|
124
|
+
group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
|
|
125
|
+
means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
|
|
122
126
|
|
|
123
127
|
Inputs:
|
|
124
128
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
@@ -139,14 +143,17 @@ class AllReduce(Primitive):
|
|
|
139
143
|
Before running the following examples, you need to configure the communication environment variables.
|
|
140
144
|
|
|
141
145
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
142
|
-
Please see the `
|
|
143
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
146
|
+
Please see the `rank table Startup
|
|
147
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
144
148
|
for more details.
|
|
145
149
|
|
|
146
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
147
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
150
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
151
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
148
152
|
|
|
149
|
-
|
|
153
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
154
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
155
|
+
|
|
156
|
+
This example should be run with 2 devices.
|
|
150
157
|
|
|
151
158
|
>>> import numpy as np
|
|
152
159
|
>>> from mindspore.communication import init
|
|
@@ -170,6 +177,11 @@ class AllReduce(Primitive):
|
|
|
170
177
|
>>> print(output)
|
|
171
178
|
[[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
172
179
|
[2. 2. 2. 2. 2. 2. 2. 2.]]
|
|
180
|
+
|
|
181
|
+
Tutorial Examples:
|
|
182
|
+
- `Distributed Set Communication Primitives - AllReduce
|
|
183
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#allreduce>`_
|
|
184
|
+
|
|
173
185
|
"""
|
|
174
186
|
|
|
175
187
|
@prim_attr_register
|
|
@@ -197,7 +209,8 @@ class AllGather(PrimitiveWithInfer):
|
|
|
197
209
|
- Currently only supports GRAPH_MODE and it should be called in Cell.
|
|
198
210
|
|
|
199
211
|
Args:
|
|
200
|
-
group (str): The communication group to work on. Default:
|
|
212
|
+
group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
|
|
213
|
+
means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
|
|
201
214
|
|
|
202
215
|
Inputs:
|
|
203
216
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
@@ -219,12 +232,15 @@ class AllGather(PrimitiveWithInfer):
|
|
|
219
232
|
Before running the following examples, you need to configure the communication environment variables.
|
|
220
233
|
|
|
221
234
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
222
|
-
Please see the `
|
|
223
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
235
|
+
Please see the `rank table Startup
|
|
236
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
224
237
|
for more details.
|
|
225
238
|
|
|
226
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
227
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
239
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
240
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
241
|
+
|
|
242
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
243
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
228
244
|
|
|
229
245
|
This example should be run with 2 devices.
|
|
230
246
|
|
|
@@ -253,6 +269,11 @@ class AllGather(PrimitiveWithInfer):
|
|
|
253
269
|
[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
254
270
|
[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
255
271
|
[1. 1. 1. 1. 1. 1. 1. 1.]]
|
|
272
|
+
|
|
273
|
+
Tutorial Examples:
|
|
274
|
+
- `Distributed Set Communication Primitives - AllGather
|
|
275
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#allgather>`_
|
|
276
|
+
|
|
256
277
|
"""
|
|
257
278
|
|
|
258
279
|
@prim_attr_register
|
|
@@ -268,9 +289,6 @@ class AllGather(PrimitiveWithInfer):
|
|
|
268
289
|
self.add_prim_attr('mean_flag', False)
|
|
269
290
|
self.add_prim_attr('no_eliminate', True)
|
|
270
291
|
|
|
271
|
-
def __call__(self, tensor):
|
|
272
|
-
raise NotImplementedError
|
|
273
|
-
|
|
274
292
|
def infer_shape(self, x_shape):
|
|
275
293
|
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
|
276
294
|
if x_shape[0] > 0:
|
|
@@ -288,8 +306,8 @@ class _MiniStepAllGather(PrimitiveWithInfer):
|
|
|
288
306
|
internal use of parallel modules and cannot be called by users.
|
|
289
307
|
|
|
290
308
|
Args:
|
|
291
|
-
group (str): The communication group to work on. Default: None.
|
|
292
|
-
grad_accumulation_step (int): The grad accumulation step. Default: None.
|
|
309
|
+
group (str): The communication group to work on. Default: ``None`` .
|
|
310
|
+
grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
|
|
293
311
|
"""
|
|
294
312
|
|
|
295
313
|
@prim_attr_register
|
|
@@ -324,7 +342,7 @@ class _MicroStepAllGather(PrimitiveWithInfer):
|
|
|
324
342
|
internal use of parallel modules and cannot be called by users.
|
|
325
343
|
|
|
326
344
|
Args:
|
|
327
|
-
group (str): The communication group to work on. Default: None.
|
|
345
|
+
group (str): The communication group to work on. Default: ``None`` .
|
|
328
346
|
"""
|
|
329
347
|
|
|
330
348
|
@prim_attr_register
|
|
@@ -364,7 +382,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
|
|
364
382
|
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
|
|
365
383
|
|
|
366
384
|
Args:
|
|
367
|
-
group (Union[tuple[int],list[int]]): The rank_ids of communication group to work on. Default: None.
|
|
385
|
+
group (Union[tuple[int],list[int]]): The rank_ids of communication group to work on. Default: ``None`` .
|
|
368
386
|
|
|
369
387
|
Raises:
|
|
370
388
|
TypeError: If group is not a list nor tuple, or elements of group are not int.
|
|
@@ -410,16 +428,14 @@ class _HostAllGather(PrimitiveWithInfer):
|
|
|
410
428
|
class ReduceScatter(Primitive):
|
|
411
429
|
r"""
|
|
412
430
|
Reduces and scatters tensors from the specified communication group.
|
|
413
|
-
For more details about it, please refer to `Distributed Set Communication Primitives - ReduceScatter \
|
|
414
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#reducescatter>`_ .
|
|
415
431
|
|
|
416
432
|
Note:
|
|
417
433
|
The tensors must have the same shape and format in all processes of the collection.
|
|
418
434
|
|
|
419
435
|
Args:
|
|
420
|
-
op (str): Specifies an operation used for element-wise reductions,
|
|
421
|
-
like SUM and MAX. Default: ReduceOp.SUM.
|
|
422
|
-
group (str): The communication group to work on. Default:
|
|
436
|
+
op (str, optional): Specifies an operation used for element-wise reductions,
|
|
437
|
+
like SUM and MAX. Default: ``ReduceOp.SUM`` .
|
|
438
|
+
group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
|
|
423
439
|
|
|
424
440
|
Inputs:
|
|
425
441
|
- **input_x** (Tensor) - Input Tensor, suppose it has a shape :math:`(N, *)`, where `*`
|
|
@@ -441,12 +457,15 @@ class ReduceScatter(Primitive):
|
|
|
441
457
|
Before running the following examples, you need to configure the communication environment variables.
|
|
442
458
|
|
|
443
459
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
444
|
-
Please see the `
|
|
445
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
460
|
+
Please see the `rank table Startup
|
|
461
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
446
462
|
for more details.
|
|
447
463
|
|
|
448
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
449
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
464
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
465
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
466
|
+
|
|
467
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
468
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
450
469
|
|
|
451
470
|
This example should be run with 2 devices.
|
|
452
471
|
|
|
@@ -476,6 +495,11 @@ class ReduceScatter(Primitive):
|
|
|
476
495
|
[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
477
496
|
[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
478
497
|
[2. 2. 2. 2. 2. 2. 2. 2.]]
|
|
498
|
+
|
|
499
|
+
Tutorial Examples:
|
|
500
|
+
- `Distributed Set Communication Primitives - ReduceScatter
|
|
501
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#reducescatter>`_
|
|
502
|
+
|
|
479
503
|
"""
|
|
480
504
|
|
|
481
505
|
@prim_attr_register
|
|
@@ -490,9 +514,6 @@ class ReduceScatter(Primitive):
|
|
|
490
514
|
self.add_prim_attr('fusion', 0)
|
|
491
515
|
self.add_prim_attr('no_eliminate', True)
|
|
492
516
|
|
|
493
|
-
def __call__(self, tensor):
|
|
494
|
-
raise NotImplementedError
|
|
495
|
-
|
|
496
517
|
|
|
497
518
|
class _HostReduceScatter(PrimitiveWithInfer):
|
|
498
519
|
"""
|
|
@@ -506,8 +527,8 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|
|
506
527
|
|
|
507
528
|
Args:
|
|
508
529
|
op (str): Specifies an operation used for element-wise reductions,
|
|
509
|
-
like sum, max, avg. Default: ReduceOp.SUM.
|
|
510
|
-
group (Union[tuple[int],list[int]]): The rank_ids of communication group to work on. Default: None.
|
|
530
|
+
like sum, max, avg. Default: ``ReduceOp.SUM`` .
|
|
531
|
+
group (Union[tuple[int],list[int]]): The rank_ids of communication group to work on. Default: ``None`` .
|
|
511
532
|
|
|
512
533
|
Raises:
|
|
513
534
|
TypeError: If op is not a string and group is not a list nor tuple,
|
|
@@ -558,7 +579,7 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
558
579
|
Args:
|
|
559
580
|
root_rank (int): Source rank. Required in all processes except the one
|
|
560
581
|
that is sending the data.
|
|
561
|
-
group (str): The communication group to work on. Default:
|
|
582
|
+
group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
|
|
562
583
|
|
|
563
584
|
Inputs:
|
|
564
585
|
- **input_x** (tuple[Tensor]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
@@ -578,12 +599,15 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
578
599
|
Before running the following examples, you need to configure the communication environment variables.
|
|
579
600
|
|
|
580
601
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
581
|
-
Please see the `
|
|
582
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
602
|
+
Please see the `rank table Startup
|
|
603
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
583
604
|
for more details.
|
|
584
605
|
|
|
585
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
586
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
606
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
607
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
608
|
+
|
|
609
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
610
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
587
611
|
|
|
588
612
|
This example should be run with multiple devices.
|
|
589
613
|
|
|
@@ -611,6 +635,11 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
611
635
|
(Tensor(shape[2,4], dtype=Int32, value=
|
|
612
636
|
[[1, 1, 1, 1],
|
|
613
637
|
[1, 1, 1, 1]]),)
|
|
638
|
+
|
|
639
|
+
Tutorial Examples:
|
|
640
|
+
- `Distributed Set Communication Primitives - Broadcast
|
|
641
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#broadcast>`_
|
|
642
|
+
|
|
614
643
|
"""
|
|
615
644
|
|
|
616
645
|
@prim_attr_register
|
|
@@ -629,7 +658,7 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
629
658
|
if not isinstance(x_dtype, tuple):
|
|
630
659
|
raise TypeError(f"For '{self.name}', the 'input_x' must be a tuple, but got {type(x_dtype).__name__}!")
|
|
631
660
|
for _ele in x_dtype:
|
|
632
|
-
check_collective_target_dtype('input_x', _ele, self.name)
|
|
661
|
+
check_collective_target_dtype('tuple input_x', _ele, self.name)
|
|
633
662
|
return x_dtype
|
|
634
663
|
|
|
635
664
|
|
|
@@ -671,7 +700,7 @@ class _AllSwap(PrimitiveWithCheck):
|
|
|
671
700
|
self.add_prim_attr('order_enforce_skip', True)
|
|
672
701
|
|
|
673
702
|
def __check__(self, tensor_in, send_size, recv_size):
|
|
674
|
-
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.
|
|
703
|
+
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor_type, self.name)
|
|
675
704
|
validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
|
|
676
705
|
self.name)
|
|
677
706
|
validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
|
|
@@ -699,11 +728,11 @@ class NeighborExchange(Primitive):
|
|
|
699
728
|
The user needs to preset
|
|
700
729
|
communication environment variables before running the following example, please check the details on the
|
|
701
730
|
official website of `MindSpore \
|
|
702
|
-
<https://www.mindspore.cn/docs/en/r2.
|
|
731
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.ops.primitive.html#communication-operator>`_.
|
|
703
732
|
|
|
704
733
|
This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
|
|
705
734
|
in the same subnet, please check the `details \
|
|
706
|
-
<https://www.mindspore.cn/
|
|
735
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#notes>`_.
|
|
707
736
|
|
|
708
737
|
Args:
|
|
709
738
|
send_rank_ids (list(int)): Ranks which the data is sent to.
|
|
@@ -711,7 +740,7 @@ class NeighborExchange(Primitive):
|
|
|
711
740
|
recv_shapes (tuple(list(int))): Data shapes received from recv_rank_ids.
|
|
712
741
|
send_shapes (tuple(list(int))): Data shapes sent to send_rank_ids.
|
|
713
742
|
recv_type (type): Data type received from recv_rank_ids.
|
|
714
|
-
group (str): The communication group to work on. Default:
|
|
743
|
+
group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
|
|
715
744
|
|
|
716
745
|
Inputs:
|
|
717
746
|
- **input_x** (tuple[Tensor]) - Shapes are same as args of send_shapes.
|
|
@@ -742,13 +771,18 @@ class NeighborExchange(Primitive):
|
|
|
742
771
|
... def construct(self, x):
|
|
743
772
|
... out = self.neighborexchange((x,))
|
|
744
773
|
...
|
|
745
|
-
>>> ms.set_context(mode=ms.GRAPH_MODE
|
|
774
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
746
775
|
>>> init()
|
|
747
776
|
>>> net = Net()
|
|
748
777
|
>>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32)
|
|
749
778
|
>>> output = net(input_x)
|
|
750
779
|
>>> print(output)
|
|
751
780
|
[[2. 2.], [2. 2.]]
|
|
781
|
+
|
|
782
|
+
Tutorial Examples:
|
|
783
|
+
- `Distributed Set Communication Primitives - NeighborExchange
|
|
784
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#neighborexchange>`_
|
|
785
|
+
|
|
752
786
|
"""
|
|
753
787
|
|
|
754
788
|
@prim_attr_register
|
|
@@ -760,6 +794,7 @@ class NeighborExchange(Primitive):
|
|
|
760
794
|
self.recv_shapes = recv_shapes
|
|
761
795
|
self.send_shapes = send_shapes
|
|
762
796
|
self.recv_type = recv_type
|
|
797
|
+
self.add_prim_attr('group', _get_group(group))
|
|
763
798
|
self.add_prim_attr('no_eliminate', True)
|
|
764
799
|
|
|
765
800
|
def __call__(self, tensor):
|
|
@@ -779,13 +814,13 @@ class AlltoAll(PrimitiveWithInfer):
|
|
|
779
814
|
Note:
|
|
780
815
|
This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
|
|
781
816
|
in the same subnet, please check the `details \
|
|
782
|
-
<https://www.mindspore.cn/
|
|
817
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#notes>`_.
|
|
783
818
|
|
|
784
819
|
Args:
|
|
785
820
|
split_count (int): On each process, divide blocks into split_count number.
|
|
786
821
|
split_dim (int): On each process, split blocks along the split_dim.
|
|
787
822
|
concat_dim (int): On each process, gather the received blocks along the concat_dimension.
|
|
788
|
-
group (str): The communication group to work on. Default:
|
|
823
|
+
group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
|
|
789
824
|
|
|
790
825
|
Inputs:
|
|
791
826
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
@@ -809,12 +844,15 @@ class AlltoAll(PrimitiveWithInfer):
|
|
|
809
844
|
Before running the following examples, you need to configure the communication environment variables.
|
|
810
845
|
|
|
811
846
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
812
|
-
Please see the `
|
|
813
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
847
|
+
Please see the `rank table Startup
|
|
848
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
814
849
|
for more details.
|
|
815
850
|
|
|
816
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
817
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
851
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
852
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
853
|
+
|
|
854
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
855
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
818
856
|
|
|
819
857
|
This example should be run with 8 devices.
|
|
820
858
|
|
|
@@ -834,7 +872,7 @@ class AlltoAll(PrimitiveWithInfer):
|
|
|
834
872
|
... out = self.alltoall(x)
|
|
835
873
|
... return out
|
|
836
874
|
...
|
|
837
|
-
>>> ms.set_context(mode=ms.GRAPH_MODE
|
|
875
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
838
876
|
>>> init()
|
|
839
877
|
>>> net = Net()
|
|
840
878
|
>>> rank_id = int(os.getenv("RANK_ID"))
|
|
@@ -842,6 +880,11 @@ class AlltoAll(PrimitiveWithInfer):
|
|
|
842
880
|
>>> output = net(input_x)
|
|
843
881
|
>>> print(output)
|
|
844
882
|
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
|
|
883
|
+
|
|
884
|
+
Tutorial Examples:
|
|
885
|
+
- `Distributed Set Communication Primitives - AlltoAll
|
|
886
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#alltoall>`_
|
|
887
|
+
|
|
845
888
|
"""
|
|
846
889
|
|
|
847
890
|
@prim_attr_register
|
|
@@ -882,15 +925,13 @@ class NeighborExchangeV2(Primitive):
|
|
|
882
925
|
NeighborExchangeV2 is a collective communication operation.
|
|
883
926
|
|
|
884
927
|
NeighborExchangeV2 sends data from the local rank to ranks in the `send_rank_ids`,
|
|
885
|
-
as while receive data from `recv_rank_ids`. Please refer to
|
|
886
|
-
|
|
887
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/communicate_ops.html#neighborexchangev2>`_
|
|
888
|
-
to learn about how the data is exchanged between neighborhood devices.
|
|
928
|
+
while receiving data from `recv_rank_ids`. Please refer to the tutorial examples
|
|
929
|
+
below to learn about how the data is exchanged between neighborhood devices.
|
|
889
930
|
|
|
890
931
|
Note:
|
|
891
932
|
This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
|
|
892
933
|
in the same subnet, please check the `details \
|
|
893
|
-
<https://www.mindspore.cn/
|
|
934
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#notes>`_.
|
|
894
935
|
|
|
895
936
|
Args:
|
|
896
937
|
send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one
|
|
@@ -902,8 +943,8 @@ class NeighborExchangeV2(Primitive):
|
|
|
902
943
|
recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of
|
|
903
944
|
[recv_top, recv_bottom, recv_left, recv_right].
|
|
904
945
|
data_format (str): Data format, only support NCHW now.
|
|
905
|
-
group (str, optional): The communication group to work on. Default:
|
|
906
|
-
"hccl_world_group" in Ascend, and "nccl_world_group" in GPU.
|
|
946
|
+
group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
|
|
947
|
+
means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
|
|
907
948
|
|
|
908
949
|
Inputs:
|
|
909
950
|
- **input_x** (Tensor) - The Tensor before being exchanged. It has a shape of :math:`(N, C, H, W)`.
|
|
@@ -927,42 +968,68 @@ class NeighborExchangeV2(Primitive):
|
|
|
927
968
|
Before running the following examples, you need to configure the communication environment variables.
|
|
928
969
|
|
|
929
970
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
930
|
-
Please see the `
|
|
931
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
971
|
+
Please see the `rank table Startup
|
|
972
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
932
973
|
for more details.
|
|
933
974
|
|
|
934
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
935
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
975
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
976
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
977
|
+
|
|
978
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
979
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
936
980
|
|
|
937
981
|
This example should be run with 2 devices.
|
|
938
982
|
|
|
939
983
|
>>> import os
|
|
940
984
|
>>> import mindspore as ms
|
|
941
|
-
>>> from mindspore import Tensor
|
|
942
985
|
>>> from mindspore.communication import init
|
|
943
986
|
>>> import mindspore.nn as nn
|
|
944
987
|
>>> import mindspore.ops as ops
|
|
945
988
|
>>> import numpy as np
|
|
946
|
-
>>>
|
|
989
|
+
>>>
|
|
990
|
+
>>> class Net0(nn.Cell):
|
|
947
991
|
... def __init__(self):
|
|
948
|
-
... super(
|
|
949
|
-
... self.
|
|
950
|
-
...
|
|
951
|
-
...
|
|
952
|
-
...
|
|
953
|
-
... data_format="NCHW")
|
|
992
|
+
... super(Net0, self).__init__()
|
|
993
|
+
... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
|
994
|
+
... send_lens=[0, 1, 0, 0],
|
|
995
|
+
... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
|
996
|
+
... recv_lens=[0, 1, 0, 0], data_format="NCHW")
|
|
954
997
|
...
|
|
955
998
|
... def construct(self, x):
|
|
956
|
-
... out = self.
|
|
999
|
+
... out = self.neighbor_exchangev2(x)
|
|
957
1000
|
... return out
|
|
1001
|
+
>>>
|
|
1002
|
+
... class Net1(nn.Cell):
|
|
1003
|
+
... def __init__(self):
|
|
1004
|
+
... super(Net1, self).__init__()
|
|
1005
|
+
... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
|
|
1006
|
+
... send_lens=[1, 0, 0, 0],
|
|
1007
|
+
... recv_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
|
|
1008
|
+
... recv_lens=[1, 0, 0, 0], data_format="NCHW")
|
|
958
1009
|
...
|
|
959
|
-
|
|
1010
|
+
... def construct(self, x):
|
|
1011
|
+
... out = self.neighbor_exchangev2(x)
|
|
1012
|
+
... return out
|
|
1013
|
+
>>>
|
|
1014
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
960
1015
|
>>> init()
|
|
961
|
-
>>>
|
|
962
|
-
>>>
|
|
963
|
-
>>>
|
|
964
|
-
>>>
|
|
1016
|
+
>>> rank_id = int(os.getenv("RANK_ID"))
|
|
1017
|
+
>>> if (rank_id % 2 == 0):
|
|
1018
|
+
>>> input_x = ms.Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32)
|
|
1019
|
+
>>> net = Net0()
|
|
1020
|
+
>>> output = net(input_x)
|
|
1021
|
+
>>> print(output)
|
|
1022
|
+
>>> else:
|
|
1023
|
+
>>> input_x = ms.Tensor(np.ones([1, 1, 2, 2]) * 2, dtype = ms.float32)
|
|
1024
|
+
>>> net = Net1()
|
|
1025
|
+
>>> output = net(input_x)
|
|
1026
|
+
>>> print(output)
|
|
965
1027
|
[[[[1. 1.], [1. 1.], [2. 2.]]]]
|
|
1028
|
+
|
|
1029
|
+
Tutorial Examples:
|
|
1030
|
+
- `Distributed Set Communication Primitives - NeighborExchangeV2
|
|
1031
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/samples/ops/communicate_ops.html#neighborexchangev2>`_
|
|
1032
|
+
|
|
966
1033
|
"""
|
|
967
1034
|
|
|
968
1035
|
@prim_attr_register
|
|
@@ -976,6 +1043,15 @@ class NeighborExchangeV2(Primitive):
|
|
|
976
1043
|
self.format = data_format
|
|
977
1044
|
self.add_prim_attr('group', _get_group(group))
|
|
978
1045
|
self.add_prim_attr('no_eliminate', True)
|
|
1046
|
+
self.rank_size = get_group_size(_get_group(group))
|
|
1047
|
+
for rank_id in send_rank_ids:
|
|
1048
|
+
if rank_id != -1:
|
|
1049
|
+
validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
|
|
1050
|
+
"rank_id in send_rank_ids")
|
|
1051
|
+
for rank_id in recv_rank_ids:
|
|
1052
|
+
if rank_id != -1:
|
|
1053
|
+
validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
|
|
1054
|
+
"rank_id in recv_rank_ids")
|
|
979
1055
|
|
|
980
1056
|
def __call__(self, tensor):
|
|
981
1057
|
raise NotImplementedError
|
|
@@ -987,9 +1063,9 @@ class _MirrorOperator(PrimitiveWithInfer):
|
|
|
987
1063
|
internal use of parallel modules and cannot be called by users.
|
|
988
1064
|
|
|
989
1065
|
Args:
|
|
990
|
-
group (str): The communication group to work on. Default: None.
|
|
991
|
-
dev_num (int): The device number of the group. Default: None.
|
|
992
|
-
mean_flag (bool): Whether use mean in backward. Default: None.
|
|
1066
|
+
group (str): The communication group to work on. Default: ``None`` .
|
|
1067
|
+
dev_num (int): The device number of the group. Default: ``None`` .
|
|
1068
|
+
mean_flag (bool): Whether use mean in backward. Default: ``None`` .
|
|
993
1069
|
"""
|
|
994
1070
|
|
|
995
1071
|
@prim_attr_register
|
|
@@ -1017,10 +1093,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
|
|
|
1017
1093
|
internal use of parallel modules and cannot be called by users.
|
|
1018
1094
|
|
|
1019
1095
|
Args:
|
|
1020
|
-
group (str): The communication group to work on. Default: None.
|
|
1021
|
-
dev_num (int): The device number of the group. Default: None.
|
|
1022
|
-
mean_flag (bool): Whether use mean in backward. Default: None.
|
|
1023
|
-
grad_accumulation_step (int): The grad accumulation step. Default: None.
|
|
1096
|
+
group (str): The communication group to work on. Default: ``None`` .
|
|
1097
|
+
dev_num (int): The device number of the group. Default: ``None`` .
|
|
1098
|
+
mean_flag (bool): Whether use mean in backward. Default: ``None`` .
|
|
1099
|
+
grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
|
|
1024
1100
|
"""
|
|
1025
1101
|
|
|
1026
1102
|
@prim_attr_register
|
|
@@ -1176,9 +1252,9 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
|
|
|
1176
1252
|
internal use of parallel modules and cannot be called by users.
|
|
1177
1253
|
|
|
1178
1254
|
Args:
|
|
1179
|
-
group (str): The communication group to work on. Default: None.
|
|
1180
|
-
dev_num (int): The device number of the group. Default: None.
|
|
1181
|
-
mean_flag (bool): Whether use mean in backward. Default: None.
|
|
1255
|
+
group (str): The communication group to work on. Default: ``None`` .
|
|
1256
|
+
dev_num (int): The device number of the group. Default: ``None`` .
|
|
1257
|
+
mean_flag (bool): Whether use mean in backward. Default: ``None`` .
|
|
1182
1258
|
"""
|
|
1183
1259
|
|
|
1184
1260
|
@prim_attr_register
|
|
@@ -25,8 +25,8 @@ class GeSwitch(PrimitiveWithInfer):
|
|
|
25
25
|
"""
|
|
26
26
|
Adds control switch to data.
|
|
27
27
|
|
|
28
|
-
Switch data flows into
|
|
29
|
-
the
|
|
28
|
+
Switch data flows into ``False`` or ``True`` branch depending on the condition. If the condition is ``True`` ,
|
|
29
|
+
the ``True`` branch will be activated, or vice versa.
|
|
30
30
|
|
|
31
31
|
Inputs:
|
|
32
32
|
- **data** (Union[Tensor, Number]) - The data to be used for switch control.
|
|
@@ -81,7 +81,7 @@ class GeSwitch(PrimitiveWithInfer):
|
|
|
81
81
|
|
|
82
82
|
def infer_dtype(self, data_type, pred_type):
|
|
83
83
|
validator.check_subclass(
|
|
84
|
-
"data", data_type, (mstype.
|
|
84
|
+
"data", data_type, (mstype.tensor_type,) + mstype.number_type, self.name)
|
|
85
85
|
validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
|
|
86
86
|
return data_type, data_type
|
|
87
87
|
|