mindspore 2.3.0__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +0 -1512
- mindspore/__init__.py +1 -2
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +25 -5
- mindspore/_extends/graph_kernel/model/graph_parallel.py +1 -1
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +0 -29
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +5 -21
- mindspore/_extends/parse/resources.py +7 -5
- mindspore/_extends/parse/standard_method.py +59 -40
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -26
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/less_batch_normalization.py +6 -9
- mindspore/common/__init__.py +1 -8
- mindspore/common/_register_for_tensor.py +9 -8
- mindspore/common/api.py +65 -275
- mindspore/common/dtype.py +4 -8
- mindspore/common/dump.py +5 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/lazy_inline.py +2 -14
- mindspore/common/parameter.py +15 -14
- mindspore/common/recompute.py +5 -20
- mindspore/common/sparse_tensor.py +6 -21
- mindspore/common/tensor.py +52 -100
- mindspore/communication/__init__.py +11 -6
- mindspore/communication/management.py +94 -92
- mindspore/context.py +18 -180
- mindspore/dataset/engine/datasets.py +46 -69
- mindspore/dataset/engine/datasets_user_defined.py +53 -72
- mindspore/dataset/engine/datasets_vision.py +2 -2
- mindspore/dataset/engine/queue.py +38 -56
- mindspore/dataset/engine/validators.py +5 -11
- mindspore/dataset/vision/__init__.py +5 -5
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +46 -591
- mindspore/dataset/vision/utils.py +1 -121
- mindspore/dataset/vision/validators.py +3 -9
- mindspore/hal/__init__.py +1 -7
- mindspore/hal/device.py +1 -1
- mindspore/include/api/model.h +0 -3
- mindspore/include/dataset/vision.h +2 -54
- mindspore/include/mindapi/base/types.h +0 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +0 -35
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +0 -72
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/{aclnn_all_finite.h → aclnn_add_custom.h} +11 -9
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +1 -1
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +1 -1
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +12 -184
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +15 -7
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +15 -7
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +31 -77
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +31 -77
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +5 -4
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +286 -275
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/add_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/backend_param.h +0 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/cast/cast_tiling.h +45 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/compare/compare_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_impl.h +4 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_tiling.h +4 -11
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/kernel/flash_attention_score_mix_hwsync.h +0 -18
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_kernel.h +0 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_rtbackend.h +75 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/kernel/matmul.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +3 -18
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_common_tiling.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_info.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_data.h +3 -36
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/kernel/matmul_stridedslice_fusion.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +4 -22
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +2 -16
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +3 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_impl.h +4 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +4 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/attention_param.h +2 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_qkv_param.h +4 -10
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +12 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +1 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/backend.h +2 -10
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/elewise_utils.h +1 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +0 -17
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/math.h +7 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/filewriter.py +2 -2
- mindspore/mint/__init__.py +40 -720
- mindspore/mint/nn/__init__.py +7 -89
- mindspore/mint/nn/functional.py +16 -165
- mindspore/mint/optim/adamw.py +16 -15
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +98 -97
- mindspore/nn/extend/basic.py +2 -2
- mindspore/nn/extend/embedding.py +1 -1
- mindspore/nn/extend/layer/normalization.py +5 -7
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/activation.py +3 -4
- mindspore/nn/layer/basic.py +16 -79
- mindspore/nn/layer/conv.py +8 -17
- mindspore/nn/layer/embedding.py +4 -1
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +1 -1
- mindspore/nn/layer/pooling.py +0 -5
- mindspore/nn/layer/rnn_cells.py +2 -2
- mindspore/nn/loss/loss.py +19 -19
- mindspore/nn/optim/adasum.py +1 -1
- mindspore/nn/optim/sgd.py +2 -3
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/wrap/cell_wrapper.py +1 -25
- mindspore/nn/wrap/loss_scale.py +1 -24
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/math_ops.py +8 -8
- mindspore/ops/__init__.py +1 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -75
- mindspore/ops/_vmap/vmap_array_ops.py +0 -27
- mindspore/ops/_vmap/vmap_math_ops.py +1 -29
- mindspore/ops/_vmap/vmap_nn_ops.py +18 -19
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +8 -34
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +9 -2
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -26
- mindspore/ops/auto_generate/gen_extend_func.py +27 -603
- mindspore/ops/auto_generate/gen_ops_def.py +203 -993
- mindspore/ops/auto_generate/gen_ops_prim.py +402 -1946
- mindspore/ops/auto_generate/pyboost_inner_prim.py +20 -90
- mindspore/ops/composite/base.py +6 -3
- mindspore/ops/composite/math_ops.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +17 -24
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/extend/__init__.py +3 -2
- mindspore/ops/extend/array_func.py +51 -10
- mindspore/ops/extend/nn_func.py +78 -2
- mindspore/ops/function/__init__.py +13 -8
- mindspore/ops/function/array_func.py +179 -455
- mindspore/ops/function/clip_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +3 -3
- mindspore/ops/function/math_func.py +103 -117
- mindspore/ops/function/nn_func.py +163 -275
- mindspore/ops/function/other_func.py +2 -2
- mindspore/ops/function/random_func.py +69 -202
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/functional.py +327 -332
- mindspore/ops/operations/__init__.py +3 -13
- mindspore/ops/operations/_grad_ops.py +27 -3
- mindspore/ops/operations/_inner_ops.py +356 -53
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +65 -82
- mindspore/ops/operations/comm_ops.py +93 -784
- mindspore/ops/operations/custom_ops.py +28 -51
- mindspore/ops/operations/debug_ops.py +4 -4
- mindspore/ops/operations/inner_ops.py +2 -2
- mindspore/ops/operations/manually_defined/ops_def.py +4 -304
- mindspore/ops/operations/math_ops.py +50 -3
- mindspore/ops/operations/nn_ops.py +247 -14
- mindspore/ops/operations/other_ops.py +3 -3
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +1 -1
- mindspore/ops/primitive.py +8 -9
- mindspore/ops/silent_check.py +5 -5
- mindspore/ops_generate/arg_dtype_cast.py +9 -2
- mindspore/ops_generate/arg_handler.py +0 -26
- mindspore/ops_generate/gen_aclnn_implement.py +4 -1
- mindspore/ops_generate/gen_ops.py +4 -26
- mindspore/ops_generate/gen_pyboost_func.py +12 -41
- mindspore/ops_generate/gen_utils.py +0 -21
- mindspore/ops_generate/pyboost_utils.py +2 -7
- mindspore/ops_generate/template.py +0 -1
- mindspore/parallel/_auto_parallel_context.py +1 -21
- mindspore/parallel/_tensor.py +5 -0
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +1 -15
- mindspore/parallel/algo_parameter_config.py +3 -1
- mindspore/parallel/checkpoint_transform.py +9 -12
- mindspore/parallel/cluster/process_entity/_api.py +29 -28
- mindspore/parallel/cluster/process_entity/_utils.py +3 -13
- mindspore/parallel/cluster/run.py +16 -13
- mindspore/parallel/parameter_broadcast.py +2 -2
- mindspore/parallel/shard.py +17 -31
- mindspore/profiler/__init__.py +2 -3
- mindspore/profiler/common/util.py +2 -107
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +21 -8
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -82
- mindspore/profiler/parser/ascend_analysis/function_event.py +28 -43
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +27 -49
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +10 -15
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +20 -25
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +5 -5
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +1 -10
- mindspore/profiler/parser/ascend_hccl_generator.py +1 -4
- mindspore/profiler/parser/ascend_msprof_exporter.py +22 -43
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -7
- mindspore/profiler/parser/minddata_parser.py +3 -72
- mindspore/profiler/profiling.py +59 -176
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/common/namespace.py +5 -5
- mindspore/rewrite/parsers/assign_parser.py +0 -2
- mindspore/rewrite/parsers/class_def_parser.py +4 -8
- mindspore/run_check/_check_version.py +1 -1
- mindspore/scipy/fft.py +3 -1
- mindspore/scipy/linalg.py +3 -2
- mindspore/scipy/ops.py +3 -5
- mindspore/scipy/optimize/__init__.py +2 -2
- mindspore/train/__init__.py +4 -4
- mindspore/train/anf_ir_pb2.py +2 -8
- mindspore/train/callback/__init__.py +2 -5
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_checkpoint.py +16 -104
- mindspore/train/callback/_landscape.py +1 -1
- mindspore/train/callback/_time_monitor.py +1 -1
- mindspore/train/data_sink.py +4 -5
- mindspore/train/dataset_helper.py +20 -45
- mindspore/train/model.py +38 -266
- mindspore/train/serialization.py +105 -256
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +303 -420
- mindspore/_extends/pijit/__init__.py +0 -23
- mindspore/_extends/pijit/pijit_func_white_list.py +0 -343
- mindspore/common/file_system.py +0 -48
- mindspore/common/generator.py +0 -260
- mindspore/common/no_inline.py +0 -54
- mindspore/common/np_dtype.py +0 -25
- mindspore/communication/comm_func.py +0 -1140
- mindspore/hal/memory.py +0 -326
- mindspore/lib/libavcodec.so.59 +0 -0
- mindspore/lib/libavdevice.so.59 +0 -0
- mindspore/lib/libavfilter.so.8 +0 -0
- mindspore/lib/libavformat.so.59 +0 -0
- mindspore/lib/libavutil.so.57 +0 -0
- mindspore/lib/libmindspore_np_dtype.so +0 -0
- mindspore/lib/libswresample.so.4 +0 -0
- mindspore/lib/libswscale.so.6 +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.cpp +0 -326
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.py +0 -180
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +0 -109
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +0 -38
- mindspore/lib/plugin/ascend/custom_compiler/OWNERS +0 -12
- mindspore/lib/plugin/ascend/custom_compiler/setup.py +0 -255
- mindspore/lib/plugin/ascend/custom_compiler/start.sh +0 -26
- mindspore/lib/plugin/ascend/custom_compiler/template.json +0 -40
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +0 -69
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/base_type.h +0 -133
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_creator.h +0 -32
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_param.h +0 -35
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/tiling_info.h +0 -60
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/kernel_register.h +0 -37
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/platform_configs.h +0 -89
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/rt_funcs.h +0 -135
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/add_op.h +0 -34
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_backoff_base.h +0 -62
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_elewise_op.h +0 -33
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_ops.h +0 -88
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_pa_op.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/cast_op.h +0 -52
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/matmul_op.h +0 -95
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/asd_utils.h +0 -84
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/comm_utils.h +0 -61
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp32.h +0 -224
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/and_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/div_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_impl.h +0 -48
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_tiling.h +0 -25
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/and_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/div_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_base.h +0 -260
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_kernel.h +0 -35
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/max_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/min_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/mul_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/or_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/max_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/min_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/mul_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/or_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/abs_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_impl.h +0 -47
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_tiling.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/exp_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/abs_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_base.h +0 -148
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_kernel.h +0 -31
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/exp_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/ln_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/not_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/reciprocal_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/relu_kernel.h +0 -55
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/rsqrt_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/sqrt_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/ln_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/not_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/reciprocal_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/relu_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/rsqrt_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/sqrt_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_impl.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_tiling.h +0 -187
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul.h +0 -245
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_interface.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_utils.h +0 -111
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/tiling_data.h +0 -54
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/compare_param.h +0 -31
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/elewise_param.h +0 -41
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/grouped_matmul_param.h +0 -40
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/profiling_util.h +0 -364
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_utils.h +0 -69
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_creator.h +0 -39
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_registry.h +0 -114
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/utils.h +0 -98
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.json +0 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aic_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aiv_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.json +0 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aic_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aiv_0.o +0 -0
- mindspore/mint/linalg/__init__.py +0 -22
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/nn/layer/embedding_service_layer.py +0 -393
- mindspore/ops/function/reshard_func.py +0 -102
- mindspore/ops/operations/_infer_ops.py +0 -19
- mindspore/ops/operations/reshard_ops.py +0 -53
- mindspore/profiler/common/process_pool.py +0 -41
- mindspore/profiler/common/singleton.py +0 -28
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/train/callback/_cluster_monitor.py +0 -201
- mindspore/train/callback/_flops_collector.py +0 -238
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
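
The file-level summary above can be reproduced locally from the two wheels. A minimal sketch, assuming both .whl files have been downloaded next to the script (the paths and the entries() helper are illustrative, not part of any shown tooling):

import zipfile

OLD_WHEEL = "mindspore-2.3.0-cp39-none-any.whl"      # hypothetical local path
NEW_WHEEL = "mindspore-2.3.0rc2-cp39-none-any.whl"   # hypothetical local path

def entries(path):
    """Map archive member name -> uncompressed size."""
    with zipfile.ZipFile(path) as zf:
        return {info.filename: info.file_size for info in zf.infolist()}

old, new = entries(OLD_WHEEL), entries(NEW_WHEEL)
for name in sorted(old.keys() | new.keys()):
    if name not in new:
        print(f"removed  {name}")
    elif name not in old:
        print(f"added    {name}")
    elif old[name] != new[name]:
        print(f"changed  {name} ({old[name]} -> {new[name]} bytes)")

Binary members such as the .so and .o files can only be flagged as changed, which is why they appear with +0 -0 line counts above.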
mindspore/ops/operations/__init__.py

@@ -50,12 +50,12 @@ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchT
                         ScatterNdMul, SegmentMean, SegmentProd, SegmentSum, SegmentMax, SegmentMin, Tril, Triu,
                         UniqueConsecutive, UnravelIndex, FillV2, CountNonZero, TensorScatterElements, IndexPut,
                         MaskedScatter)
-from .comm_ops import (AllGather, AllReduce,
-                       Broadcast,
+from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
+                       Broadcast,
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                        _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
                        _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather,
-                       _VirtualPipelineEnd
+                       _VirtualPipelineEnd)
 from .control_ops import GeSwitch, Merge
 from .custom_ops import (Custom)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,

@@ -136,7 +136,6 @@ from ..deprecated import (identity, DropoutDoMask, MaxPoolWithArgmax, DropoutGen
                           TensorAdd, InplaceUpdate, ScatterNonAliasingAdd,
                           BatchToSpaceND, Unpack, GatherV2, DynamicShape, ScalarToArray, Pack)
 from .manually_defined._inner import ScalarCast
-from .reshard_ops import (Reshard)

 __all__ = [
     'HSVToRGB',

@@ -392,11 +391,9 @@ __all__ = [
     'UnsortedSegmentProd',
     "AllGather",
     "AllReduce",
-    "Reduce",
     "_AllSwap",
     "ReduceScatter",
     "Broadcast",
-    "BatchISendIRecv",
     "ReduceOp",
     'ScalarCast',
     'GetNext',

@@ -531,12 +528,6 @@ __all__ = [
     "NeighborExchangeV2",
     "NeighborExchange",
     "AlltoAll",
-    "AlltoAllV",
-    "CollectiveGather",
-    "CollectiveScatter",
-    "Barrier",
-    "Send",
-    "Receive",
     "Custom",
     "LuSolve",
     "CholeskyInverse",

@@ -704,7 +695,6 @@ __all__ = [
     "ReshapeAndCache",
     "ApplyRotaryPosEmb",
     "RmsNorm",
-    "Reshard",
 ]

 __custom__ = [
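
The net effect of the __all__ edits above: Reduce, BatchISendIRecv, AlltoAllV, CollectiveGather, CollectiveScatter, Barrier, Send, Receive, and Reshard are exported in 2.3.0 but absent from the 2.3.0rc2 export list. Code that has to import cleanly on both builds can feature-detect instead of importing unconditionally; a minimal sketch (the maybe_barrier helper is hypothetical, and actually invoking the op still requires an initialized communication group):

import mindspore.ops as ops

def maybe_barrier():
    """Synchronize all ranks when the public Barrier op exists (2.3.0); no-op otherwise (2.3.0rc2)."""
    barrier_cls = getattr(ops, "Barrier", None)
    if barrier_cls is not None:
        barrier_cls()()  # instantiate the primitive, then call it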

mindspore/ops/operations/_grad_ops.py

@@ -34,9 +34,8 @@ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad,
                              GatherDGradV2, ResizeBilinearGrad, ResizeLinear1DGrad, ResizeNearestNeighborV2Grad,
                              SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
                              ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
-                             HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad,
-                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad,
-                             BinaryCrossEntropyGrad)
+                             HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad,
+                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad)


 class SparseFillEmptyRowsGrad(Primitive):

@@ -98,6 +97,14 @@ class KLDivLossGrad(Primitive):
         self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)


+class BinaryCrossEntropyGrad(Primitive):
+    """Computes gradients for `BinaryCrossEntropy` operation."""
+
+    @prim_attr_register
+    def __init__(self, reduction='mean'):
+        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
+
+
 class LuUnpackGrad(Primitive):
     """Computes gradients for `LuUnpack` operation."""
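
For reference, the gradient this primitive exposes is the textbook derivative of binary cross entropy. A NumPy sketch of that math, assuming the forward is L = -[y*log(x) + (1-y)*log(1-x)] and that 'mean' divides by the element count (reference math only, not the fused kernel):

import numpy as np

def binary_cross_entropy_grad(dout, x, y, reduction="mean"):
    """dL/dx for elementwise BCE; dout is the incoming gradient."""
    grad = (x - y) / np.clip(x * (1.0 - x), 1e-12, None)  # d/dx of -[y*log(x) + (1-y)*log(1-x)]
    if reduction == "mean":
        grad = grad / x.size
    return grad * dout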
@@ -3050,3 +3057,20 @@ class WKVGrad(Primitive):
         """Initialize WKVGrad."""
         self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
                                 outputs=["gw", "gu", "gk", "gv"])
+
+
+class RmsNormGrad(Primitive):
+    r"""
+    Calculates the gradient of the RmsNorm operation.
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize RmsNormGrad."""
+        self.init_prim_io_names(inputs=["dy", "x", "rstd", "gamma"],
+                                outputs=["dx", "dgamma"])
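
RmsNormGrad consumes the inverse RMS statistic (rstd) saved by the forward pass instead of recomputing it. A NumPy sketch of the gradient it is expected to produce, assuming the forward is y = x * rstd * gamma with rstd = 1/sqrt(mean(x**2, last axis) + eps) (reference math only, not the Ascend kernel):

import numpy as np

def rms_norm_grad(dy, x, rstd, gamma):
    """Backward of RMSNorm over the last axis; rstd is assumed to broadcast as (..., 1)."""
    n = x.shape[-1]
    dgamma = (dy * x * rstd).reshape(-1, n).sum(axis=0)
    # Product rule through the shared rstd statistic:
    proj = (dy * gamma * x).sum(axis=-1, keepdims=True) / n
    dx = dy * gamma * rstd - x * rstd**3 * proj
    return dx, dgamma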

mindspore/ops/operations/_inner_ops.py

@@ -14,6 +14,7 @@
 # ============================================================================

 """Inner operators."""
+# pylint: disable=unused-import
 from types import FunctionType, MethodType
 from collections.abc import Iterable
 import os

@@ -24,6 +25,7 @@ from mindspore.common._stub_tensor import StubTensor
 from mindspore.ops import composite as C
 from mindspore.ops.operations.array_ops import Cast
 from mindspore.ops.operations._scalar_ops import bit_or, bit_and
+from mindspore.ops.operations.comm_ops import ReduceOp
 from mindspore.ops import signature as sig
 from mindspore.ops.operations.math_ops import _infer_shape_reduce
 from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \

@@ -37,8 +39,8 @@ from mindspore.communication.management import GlobalComm, get_rank, _get_group,
 from mindspore.common.api import _pynative_executor
 from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore import ops
-from ..auto_generate import TensorCopySlices, SiLU, Cummin,
-
+from ..auto_generate import TensorCopySlices, SiLU, Cummin, ExtractImagePatches, DecoderKVCache, PromptKVCache, \
+    ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4

 # Bit operation
 bit_and = bit_and()
@@ -57,30 +59,6 @@ string_mul = Primitive("string_mul")
 string_getitem = Primitive("string_getitem")


-class Generator(Primitive):
-    r"""
-    Manage the state of random number generation.
-
-    Inputs:
-        - **cmd** (int) : operation to be executed.
-        - **inputs** (tuple[tensor]) : inputs for the operation.
-
-    Outputs:
-        - **seed** (Tensor): Seed for the random number generation algorithm.
-        - **offset** (Tensor): Offset of the random number sequence.
-        - **state** (Tensor): State tensor, can be used to restore current state.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        self.add_prim_attr("side_effect_mem", True)
-
-    def __call__(self, cmd, inputs):
-        if cmd == 0:  # step cmd
-            return inputs[0], inputs[1]
-        return super().__call__(cmd, inputs)
-
-
 class Quant(PrimitiveWithInfer):
     r"""
     Returns the quantized value of input_x.
@@ -389,6 +367,229 @@ class MatrixDiagPart(PrimitiveWithInfer):
         return out_shape


+class Send(PrimitiveWithInfer):
+    """
+    Send tensors from src_rank to the specified dest_rank.
+
+    Note:
+        Send and Receive must be used in combination and have the same sr_tag.
+        Send must be used between servers.
+
+    Args:
+        sr_tag (int): A required integer identifying the send/recv message tag. The message
+            will be received by the Receive op with the same "sr_tag".
+        dest_rank (int): A required integer identifying the destination rank.
+        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>>
+        >>> init()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.depend = ops.Depend()
+        >>>         self.send = ops.Send(sr_tag=0, dest_rank=8, group="hccl_world_group")
+        >>>
+        >>>     def construct(self, x):
+        >>>         out = self.depend(x, self.send(x))
+        >>>         return out
+        >>>
+        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_)
+    """
+
+    @prim_attr_register
+    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
+        self.rank = dest_rank
+        self.sr_tag = sr_tag
+        self.group = group
+        self.add_prim_attr("no_eliminate", True)
+
+    def infer_shape(self, x_shape):
+        self.add_prim_attr("shape", x_shape)
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
+
+
+class Receive(PrimitiveWithInfer):
+    """
+    Receive tensors from src_rank.
+
+    Note:
+        Send and Receive must be used in combination and have the same sr_tag.
+        Receive must be used between servers.
+
+    Args:
+        sr_tag (int): A required integer identifying the send/recv message tag. The message
+            will be sent by the Send op with the same "sr_tag".
+        src_rank (int): A required integer identifying the source rank.
+        shape (list[int]): A required list identifying the shape of the tensor to be received.
+        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
+            int8, int16, int32, float16, float32.
+        group (str, optional): The communication group to work on.
+            Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>>
+        >>> init()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
+        >>>                                 group="hccl_world_group")
+        >>>
+        >>>     def construct(self):
+        >>>         out = self.recv()
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> output = net()
+    """
+
+    @prim_attr_register
+    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
+                 group_back=GlobalComm.WORLD_COMM_GROUP):
+        self.rank = src_rank
+        self.tag = sr_tag
+        self.shape = shape
+        self.dtype = dtype
+        self.group = group
+        self.add_prim_attr("no_eliminate", True)
+        valid_type = [mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16,
+                      mstype.int8, mstype.int16, mstype.int32, mstype.int64,
+                      mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64]
+        args = {"dtype": dtype}
+        validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
+
+    def infer_shape(self, x_shape=None):
+        return self.get_attr_dict()['shape']
+
+    def infer_dtype(self, x_dtype=None):
+        return self.get_attr_dict()['dtype']
+
+
+class Reduce(PrimitiveWithInfer):
+    """
+    Reduces tensors across the processes in the specified communication group.
+
+    Note:
+        Only the process with the destination rank receives the reduced output.
+        Other processes only get a tensor with shape [1], which has no mathematical meaning.
+
+    Args:
+        dest_rank (int): Specifies the rank of the process that receives the reduced output.
+        op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
+            On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
+        group (str, optional): The communication group to work on.
+            Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> # Launch 4 processes.
+        >>> init()
+        >>> class ReduceNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(ReduceNet, self).__init__()
+        >>>         self.reduce = ops.Reduce(dest_rank=1)
+        >>>
+        >>>     def construct(self, x):
+        >>>         out = self.reduce(x)
+        >>>         return out
+        >>> input = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = ReduceNet()
+        >>> output = net(input)
+        >>> print(output)
+        Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
+                              [4. 4. 4. 4. 4. 4. 4. 4.]],
+        Other processes: [0.].
+    """
+
+    @prim_attr_register
+    def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+        self.dest_rank = dest_rank
+        self.op = op
+        self.group = group
+
+    def infer_shape(self, x_shape):
+        # The process with dest_rank returns the reduced output.
+        # Other processes only get a tensor with shape [1], which has no mathematical meaning.
+        if self.dest_rank == get_rank():
+            return x_shape
+        return [1]
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
+
+
+class Barrier(PrimitiveWithInfer):
+    """
+    Synchronizes all processes in the specified group.
+
+    Note:
+        After calling this collective operator,
+        this process will be blocked until all other processes in the group call this operator.
+
+    Args:
+        group (str, optional): The communication group to work on.
+            Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> # Launch 4 processes.
+        >>> init()
+        >>> class BarrierNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(BarrierNet, self).__init__()
+        >>>         self.barrier = ops.Barrier()
+        >>>
+        >>>     def construct(self):
+        >>>         self.barrier()
+        >>> net = BarrierNet()
+        >>> net()
+    """
+
+    @prim_attr_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
+        self.group = group
+        self.add_prim_attr("side_effect_mem", True)
+
+    def infer_shape(self):
+        return [1]
+
+    def infer_dtype(self):
+        return mstype.float32
+
+
 class MatrixSetDiag(PrimitiveWithInfer):
     r"""
     Modifies the batched diagonal part of a batched tensor.
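
Because Receive reconstructs its output purely from the attributes above, a Send/Receive pair must agree on sr_tag, shape, and dtype. A minimal two-rank sketch against the public 2.3.0 surface (assumes a distributed job already launched, e.g. with msrun; illustrative only):

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.communication import init, get_rank

init()  # requires a properly launched communication job
if get_rank() == 0:
    send = ops.Send(sr_tag=7, dest_rank=1)  # tag 7 must match the Receive below
    send(Tensor(np.ones([2, 8]).astype(np.float32)))
else:
    recv = ops.Receive(sr_tag=7, src_rank=0, shape=[2, 8], dtype=ms.float32)
    print(recv())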
@@ -2469,15 +2670,67 @@ class FFN(Primitive):
     def __init__(self, activation, inner_precise):
         """Initialize FFN."""
         self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
-                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2",
-                                        "antiquant_scale1", "antiquant_scale2",
-                                        "antiquant_offset1", "antiquant_offset2"],
+                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
                                 outputs=["y"])
         cls_name = self.name
         validator.check_value_type("activation", activation, [str], cls_name)
         validator.check_value_type("inner_precise", inner_precise, [int], cls_name)


+class CollectiveScatter(Primitive):
+    """
+    Scatters a tensor across the processes in the specified communication group.
+
+    Note:
+        Scatter operation interface over a communication domain.
+        Distributes the tensor on the specified rank evenly to every rank.
+
+    Args:
+        src_rank (int): Specifies the rank of the process that sends the tensor.
+        group (str, optional): The communication group to work on.
+            Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.communication.management import init, get_rank
+        >>> from mindspore.ops.operations import _inner_ops as inner_p
+        >>> # Launch 4 processes.
+        >>> init()
+        >>> class CollectiveScatterNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(CollectiveScatterNet, self).__init__()
+        >>>         self.collective_scatter = inner_p.CollectiveScatter(src_rank=0)
+        >>>
+        >>>     def construct(self, x):
+        >>>         out = self.collective_scatter(x)
+        >>>         return out
+        >>> input = Tensor(np.ones([8, 8]).astype(np.float32))
+        >>> net = CollectiveScatterNet()
+        >>> output = net(input)
+        >>> print(output)
+        All ranks: [[1. 1. 1. 1. 1. 1. 1. 1.]
+                    [1. 1. 1. 1. 1. 1. 1. 1.]],
+    """
+
+    @prim_attr_register
+    def __init__(self, src_rank=0, group=GlobalComm.WORLD_COMM_GROUP):
+        validator.check_value_type('group', _get_group(group), (str,), self.name)
+        self.rank_size = get_group_size(_get_group(group))
+        self.src_rank = src_rank
+
+        self.add_prim_attr('src_rank', self.src_rank)
+        self.add_prim_attr('rank_size', self.rank_size)
+        self.add_prim_attr('group', _get_group(group))
+
+
 class _MirrorSilentCheck(PrimitiveWithInfer):
     """
     The operator _MirrorSilentCheck implements accuracy-sensitive detection on the tensor input in backpropagator.
@@ -2532,36 +2785,86 @@ class _MirrorSilentCheck(PrimitiveWithInfer):
         return x_dtype


-class
-    """
-    Auto parallel virtual operator.
+class CollectiveGather(Primitive):
     """
+    Gathers tensors from the specified communication group.

-        self.input_nums = input_nums
+    Note:
+        Gather operation interface over a communication domain.
+        Combines the tensors of all ranks in rank order, then sends the result to the target rank.

+    Args:
+        group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
+            means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.

+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

+    Outputs:
+        Tensor. If the number of devices in the group is N,
+        then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.

+    Raises:
+        TypeError: If `group` is not a str.
+        ValueError: If the local rank id of the calling process in the group
+            is larger than the group's rank size.

-        """Initialize _VirtualConverterBegin."""
-        self.output_nums = output_nums
+    Supported Platforms:
+        ``Ascend``

+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table Startup
+            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
+            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
+
+            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+            Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
+
+            This example should be run with 4 devices.
+
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops.operations import _inner_ops as inner_p
+        >>>
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> init()
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.collective_gather = inner_p.CollectiveGather(dest_rank=0)
+        ...
+        ...     def construct(self, x):
+        ...         return self.collective_gather(x)
+        ...
+        >>> input_x = Tensor(np.ones([1, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_x)
+        >>> print(output)
+        [[1. 1. 1. 1. 1. 1. 1. 1.]
+         [1. 1. 1. 1. 1. 1. 1. 1.]
+         [1. 1. 1. 1. 1. 1. 1. 1.]
+         [1. 1. 1. 1. 1. 1. 1. 1.]]
+    """
+
+    @prim_attr_register
+    def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
+        """Initialize Gather."""
+        validator.check_value_type('group', _get_group(group), (str,), self.name)
+        self.rank_id = get_rank(_get_group(group))
+        self.dest_rank = dest_rank
+        self.rank_size = get_group_size(_get_group(group))
+        validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name)
+        self.add_prim_attr('rank_size', self.rank_size)
+        self.add_prim_attr('group', _get_group(group))
+        self.add_prim_attr('dest_rank', self.dest_rank)
+        self.add_prim_attr('rank_id', self.rank_id)

mindspore/ops/operations/_rl_inner_ops.py

@@ -1132,7 +1132,7 @@ class MuxSend(PrimitiveWithInfer):
     - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

     Examples:
-        >>>
+        >>> import mindspore.ops as ops
         >>> import mindspore.nn as nn
         >>> from mindspore.communication import init
         >>> from mindspore import Tensor

@@ -1190,7 +1190,7 @@ class MuxReceive(PrimitiveWithInfer):
     - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

     Examples:
-        >>>
+        >>> import mindspore.ops as ops
         >>> import mindspore.nn as nn
         >>> from mindspore.communication import init
         >>> from mindspore import Tensor

mindspore/ops/operations/_tensor_array.py

@@ -46,7 +46,7 @@ class TensorArray(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> print(handle)

@@ -90,7 +90,7 @@ class TensorArrayWrite(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> write_op = ops.TensorArrayWrite()

@@ -133,7 +133,7 @@ class TensorArrayRead(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> write_op = ops.TensorArrayWrite()

@@ -179,7 +179,7 @@ class TensorArrayClose(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> close_op = ops.TensorArrayClose()

@@ -215,7 +215,7 @@ class TensorArrayClear(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> clear_op = ops.TensorArrayClear()

@@ -255,7 +255,7 @@ class TensorArrayStack(Primitive):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> write_op = ops.TensorArrayWrite()

@@ -295,7 +295,7 @@ class TensorArraySize(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> create_op = ops.TensorArray(mindspore.int32, ())
         >>> handle = create_op()
         >>> size_op = ops.TensorArraySize()

@@ -333,7 +333,7 @@ class TensorArrayGather(PrimitiveWithInfer):

     Examples:
         >>> import mindspore
-        >>>
+        >>> import mindspore.ops as ops
         >>> from mindspore import numpy as mnp
         >>> create_op = ops.TensorArray(mindspore.float32, dynamic_size=False, element_shape=(8,))
         >>> handle = create_op()