mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/array_ops.py

@@ -42,7 +42,7 @@ from ..auto_generate import (
                             NonZero, ResizeNearestNeighbor, Identity, Split, CumSum, CumProd,
                             MaskedSelect, Cummax, Cummin, Argmin, Concat, UnsortedSegmentSum, UniqueConsecutive,
                             ScalarToTensor, Triu, BroadcastTo, StridedSlice, Select, TopkExt,
-                            SearchSorted, Meshgrid, Squeeze, Slice, TransposeExtView)
+                            SearchSorted, Meshgrid, Squeeze, Slice, TransposeExtView, MaskedScatter)
 from .manually_defined import Rank, Shape, Tile, Cast, Ones, Zeros, TypeAs
 from ..auto_generate import ArgMaxWithValue, ArgMinWithValue
 from ..auto_generate import TensorScatterElements as TensorScatterElementsExt
@@ -1048,11 +1048,11 @@ class Fill(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['type', 'shape', 'value'], outputs=['y'])

     def __call__(self, dtype, dims, x):
-        if dtype not in mstype.all_types
+        if dtype not in mstype.all_types:
             raise TypeError(
                 f"For \'{self.name}\', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', "
                 "'uint16', 'uint32', 'uint64','float16', 'float32', 'float64'], but got an invalid dtype!.")
-        x_nptype = mstype.dtype_to_nptype(dtype)
+        x_nptype = mstype._dtype_to_nptype(dtype)  # pylint:disable=protected-access
         if not isinstance(dims, Tensor) and not isinstance(dims, tuple):
             raise TypeError(f"For \'{self.name}\', input[1] must be tensor.")
         if not isinstance(x, Tensor) and not isinstance(x, float) and not isinstance(x, int):
@@ -1065,7 +1065,7 @@ class Fill(PrimitiveWithCheck):
             return Tensor(ret, dtype=dtype)

     def infer_value(self, dtype, dims, x):
-        x_nptype = mstype.dtype_to_nptype(dtype)
+        x_nptype = mstype._dtype_to_nptype(dtype)  # pylint:disable=protected-access
         if dims is not None and None not in dims and x is not None:
             if isinstance(dims, Tensor):
                 dims = dims.asnumpy()
@@ -1157,6 +1157,7 @@ class FillV2(PrimitiveWithCheck):
             init_func = Zero()
             init_func.__enable_zero_dim__ = True
             out = Tensor(shape=dims, dtype=x.dtype, init=init_func)
+            out.init_data()
             return out
         return Tensor(np.full(dims, x.asnumpy()))
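Both Fill hunks track a rename of the dtype conversion helper: the removed lines appear to call the public `mstype.dtype_to_nptype`, while 2.7.1 calls the underscored `_dtype_to_nptype` and silences pylint. A minimal migration sketch, assuming the public alias is no longer available in 2.7.1:

    # Hedged sketch: mirrors the rename visible in the hunks above.
    import numpy as np
    from mindspore.common import dtype as mstype

    # 2.7.0rc1 (inferred from the truncated removed lines):
    #     np_type = mstype.dtype_to_nptype(mstype.float32)
    # 2.7.1 (as the added lines call it):
    np_type = mstype._dtype_to_nptype(mstype.float32)  # pylint: disable=protected-access
    assert np_type is np.float32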
@@ -3974,52 +3975,6 @@ class RangeV2(Primitive):
         validator.check_positive_int(maxlen, "maxlen", self.name)


-class MaskedScatter(Primitive):
-    """
-    Updates the value in the input with value in `updates` according to the `mask`.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Inputs:
-        - **x** (Tensor): The input Tensor to be updated.
-        - **mask** (Tensor[bool]): The mask Tensor indicating which elements should be modified or replaced.
-          The shapes of `mask` and `x` must be the same or broadcastable.
-        - **updates** (Tensor): The values to scatter into the target tensor `x`. It has the same data type as `x`. The
-          number of elements must be greater than or equal to the number of True's in `mask`.
-
-    Outputs:
-        Tensor, with the same type and shape as `x`.
-
-    Raises:
-        TypeError: If `x`, `mask` or `updates` is not a Tensor.
-        TypeError: If data type of `x` is not be supported.
-        TypeError: If dtype of `mask` is not bool.
-        TypeError: If the dim of `x` less than the dim of `mask`.
-        ValueError: If `mask` can not be broadcastable to `x`.
-        ValueError: If the number of elements in `updates` is less than number of True's in `mask`.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> input_x = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
-        >>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
-        >>> updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
-        >>> output = ops.MaskedScatter()(input_x, mask, updates)
-        >>> print(output)
-        [5. 6. 3. 7.]
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """Initialize MaskedScatter"""
-        self.init_prim_io_names(inputs=['x', 'mask', 'updates'], outputs=['y'])
-
-
 class _TensorScatterOp(PrimitiveWithInfer):
     """
     Defines TensorScatter Base Operators
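The class body removed here is not a feature removal: the first hunk in this file adds `MaskedScatter` to the `..auto_generate` import list, so the primitive is now generated rather than hand-written. Call sites should be unaffected; the sketch below simply reuses the example from the removed docstring.

    import mindspore
    import numpy as np
    from mindspore import Tensor, ops

    input_x = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
    mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
    updates = Tensor(np.array([5., 6., 7.]), mindspore.float32)
    output = ops.MaskedScatter()(input_x, mask, updates)
    print(output)  # expected: [5. 6. 3. 7.]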
mindspore/ops/operations/comm_ops.py

@@ -18,10 +18,9 @@
 from __future__ import absolute_import
 from __future__ import division

-import os
 from mindspore.common import Tensor
 from mindspore import _checkparam as validator
-from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group, _host_distribute
+from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import PrimitiveWithInfer, PrimitiveWithCheck, Primitive, prim_attr_register
 from mindspore.common.api import context
@@ -98,17 +97,6 @@ def check_collective_target_dtype(data_name, data_dtype, prim_name):
     validator.check_tensor_dtype_valid(data_name, data_dtype, valid_dtype, prim_name)


-def check_hcom_group_valid(group, prim_name=None):
-    """Check if hcom group is valid."""
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    sim_level = os.getenv("MS_SIMULATION_LEVEL")
-    no_sim = (sim_level is None or sim_level.strip() == '')
-    if no_sim and (not _host_distribute()) and context.get_context("mode") == context.PYNATIVE_MODE and \
-            group != GlobalComm.WORLD_COMM_GROUP:
-        raise RuntimeError(f"{msg_prefix} 'group' only support 'hccl_world_group' in pynative mode, but got "
-                           f"'group': {group}. Please start by using mpi-run.")
-
-
 class AllReduce(Primitive):
     """
     Reduces tensors across all devices in such a way that all devices will get the same final result,
@@ -187,7 +175,6 @@ class AllReduce(Primitive):
         if not isinstance(self.group, str):
             raise TypeError(f"For '{self.name}', the 'group' must be str, "
                             f"but got {type(self.group).__name__}.")
-        check_hcom_group_valid(self.group, prim_name=self.name)
         self.op = op
         self.add_prim_attr('group', self.group)
         self.add_prim_attr('fusion', 0)
@@ -720,7 +707,6 @@ class Broadcast(PrimitiveWithInfer):
         """Initialize Broadcast."""
         validator.check_value_type('root_rank', root_rank, (int,), self.name)
         validator.check_value_type('group', _get_group(group), (str,), self.name)
-        check_hcom_group_valid(group, prim_name=self.name)
         self.add_prim_attr('group', _get_group(group))
         self.add_prim_attr('no_eliminate', True)

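Dropping `check_hcom_group_valid` (and its call sites in `AllReduce` and `Broadcast`) removes the PyNative-mode restriction that collectives could only target `hccl_world_group` unless the job was launched through MPI. A hedged sketch of what this appears to permit; the group name and rank list are illustrative and assume a two-device job with a working HCCL setup:

    # Illustrative sketch only: assumes init() succeeds under msrun/mpirun.
    import mindspore as ms
    from mindspore import ops
    from mindspore.communication import init, create_group

    ms.set_context(mode=ms.PYNATIVE_MODE)
    init()
    create_group("sub_group_0", [0, 1])  # hypothetical sub-group
    # Before this change, constructing the primitive with a non-world group
    # in PyNative mode raised RuntimeError from check_hcom_group_valid.
    all_reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group="sub_group_0")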
@@ -1954,7 +1940,7 @@ class BatchISendIRecv(PrimitiveWithInfer):


 class AlltoAllV(PrimitiveWithInfer):
-    """
+    r"""
     AllToAllV which support uneven scatter and gather compared with AllToAll.

     Note:
@@ -2015,7 +2001,7 @@ class AlltoAllV(PrimitiveWithInfer):
         ...    send_tensor = Tensor([0, 1, 2.])
         ...    send_numel_list = [1, 2]
         ...    recv_numel_list = [1, 2]
-
+        ... elif rank == 1:
         ...    send_tensor = Tensor([3, 4, 5.])
         ...    send_numel_list = [2, 1]
         ...    recv_numel_list = [2, 1]
@@ -2027,6 +2013,10 @@ class AlltoAllV(PrimitiveWithInfer):
         rank 1:
         [1. 2. 5]

+    Tutorial Examples:
+        - `Distributed Set Communication Primitives - AlltoAllV
+          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoallv>`_
+
     """

     @prim_attr_register
@@ -2038,6 +2028,94 @@ class AlltoAllV(PrimitiveWithInfer):
         self.add_prim_attr('block_size', self.block_size)


+class AlltoAllVC(PrimitiveWithInfer):
+    r"""
+    AllToAllVC passes in the sending and receiving parameters of all ranks through the input parameter
+    `send_count_matrix`. Compared to AllToAllV, AllToAllVC does not require the aggregation of all rank
+    sending and receiving parameters, thus offering superior performance.
+
+    Note:
+        Only one-dimensional input is supported; the input data must be flattened into a one-dimensional
+        array before using this interface.
+
+    Args:
+        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
+            means ``"hccl_world_group"`` in Ascend.
+        block_size (int, optional): The basic units for splitting and gathering numel by `send_count_matrix`.
+            Default: ``1``.
+        transpose (bool, optional): Indicates whether the `send_count_matrix` needs to undergo a transpose
+            operation, this parameter is used in reverse calculation scenarios. Default: ``False``.
+
+    Inputs:
+        - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`.
+        - **send_count_matrix** (Union[list[int], Tensor]) - The sending and receiving parameters of
+          all ranks, :math:`\text{send_count_matrix}[i*\text{rank_size}+j]` represents the amount of data sent by
+          rank i to rank j, and the basic unit is the number of bytes of Tensor's dtype. Among them, `rank_size`
+          indicates the size of the communication group.
+
+    Outputs:
+        Tensor. Flattened and concatenated tensor gather from remote ranks.
+        If gather result is empty, it will return a Tensor with shape `()`, and value has no actual meaning.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+
+            Please see the `msrun start up
+            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+            for more details.
+
+            This example should be run with 2 devices.
+
+        >>> from mindspore.ops import AlltoAllVC
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init, get_rank
+        >>> from mindspore import Tensor
+        >>>
+        >>> init()
+        >>> rank = get_rank()
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.all_to_all_v_c = AlltoAllVC()
+        ...
+        ...     def construct(self, x, send_count_matrix):
+        ...         return self.all_to_all_v_c(x, send_count_matrix)
+        >>> send_count_matrix = Tensor([[0, 3], [3, 0]])
+        >>> send_tensor = Tensor([0, 1, 2.]) * rank
+        >>> net = Net()
+        >>> output = net(send_tensor, send_count_matrix)
+        >>> print(output)
+        rank 0:
+        [0. 1. 2]
+        rank 1:
+        [0. 0. 0]
+
+    Tutorial Examples:
+        - `Distributed Set Communication Primitives - AlltoAllVC
+          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoallvc>`_
+
+    """
+
+    @prim_attr_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, block_size=1, transpose=False):
+        self.group = GlobalComm.WORLD_COMM_GROUP if group is None else _get_group(group)
+        self.rank_size = get_group_size(self.group)
+        self.add_prim_attr('rank_size', self.rank_size)
+        self.add_prim_attr('group', self.group)
+        self.rank_id = get_rank(_get_group(self.group))
+        self.add_prim_attr('rank_id', self.rank_id)
+        validator.check_value_type("block_size", block_size, [int], self.name)
+        self.add_prim_attr('block_size', self.block_size)
+        self.add_prim_attr('transpose', self.transpose)
+
+
 class AllGatherV(PrimitiveWithInfer):
     """
     Gathers uneven tensors from the specified communication group and returns the tensor which is all gathered.
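The `send_count_matrix` indexing convention in the new `AlltoAllVC` docstring is easy to misread, so here is a small self-contained sanity check (plain Python, no devices needed): row i of the flattened matrix holds rank i's send counts, and what rank j receives is column j.

    rank_size = 2
    # Flattened form of the docstring example Tensor([[0, 3], [3, 0]]).
    send_count_matrix = [0, 3,   # rank 0 sends 0 to itself, 3 to rank 1
                         3, 0]   # rank 1 sends 3 to rank 0, 0 to itself

    def send_counts(rank):
        return send_count_matrix[rank * rank_size:(rank + 1) * rank_size]

    def recv_counts(rank):
        return [send_count_matrix[i * rank_size + rank] for i in range(rank_size)]

    # Rank 0 sends its whole 3-element tensor to rank 1 and receives 3
    # elements back, matching the printed outputs in the docstring example.
    assert send_counts(0) == [0, 3] and recv_counts(0) == [0, 3]
    assert send_counts(1) == [3, 0] and recv_counts(1) == [3, 0]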
mindspore/ops/operations/custom_ops.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 import json
 import os
 import re
+import sys
 import ast
 import hashlib
 import stat
@@ -26,6 +27,7 @@ import inspect
 import importlib
 import platform
 import subprocess
+import shutil
 import numpy as np
 import mindspore as ms
 from mindspore._c_expression import Oplib, typing
@@ -37,6 +39,7 @@ from mindspore.ops import DataType
 from mindspore import log as logger
 from mindspore import ops
 from mindspore.communication.management import get_rank, GlobalComm
+from mindspore import _checkparam as validator
 from ._ms_kernel import determine_variable_usage
 from ._custom_grad import autodiff_bprop
 from ._pyfunc_registry import add_pyfunc
@@ -1075,17 +1078,18 @@ class Custom(ops.PrimitiveWithInfer):
                 if isinstance(arg_dtype, mstype.TensorType):
                     arg_dtype = arg_dtype.element_type()
                 fake_arg = np.zeros(arg["shape"]).astype(
-                    mstype.dtype_to_nptype(arg_dtype))
+                    mstype._dtype_to_nptype(arg_dtype))  # pylint:disable=protected-access
                 fake_input.append(fake_arg)

         fake_output = self.func(*fake_input)

         if hasattr(fake_output, 'shape'):
             infer_shape = fake_output.shape
-            infer_dtype = mstype.TensorType(mstype.pytype_to_dtype(fake_output.dtype))
+            # pylint:disable=protected-access
+            infer_dtype = mstype.TensorType(mstype._pytype_to_dtype(fake_output.dtype))
         else:
             infer_shape = (1,)
-            infer_dtype = mstype.pytype_to_dtype(fake_output.dtype)
+            infer_dtype = mstype._pytype_to_dtype(fake_output.dtype)  # pylint:disable=protected-access

         infer_value = Tensor(fake_output) if enable_infer_value else None
@@ -1184,6 +1188,54 @@ class Custom(ops.PrimitiveWithInfer):
         return ops.primitive._run_op(self, self.name, args)


+class _MultiSoProxy:
+    """
+    A thin wrapper that transparently multiplexes attribute access between a
+    pure-Python fallback module and an optional compiled shared-object (SO)
+    module, honoring MindSpore's current execution mode (GRAPH vs. PYNATIVE).
+    """
+
+    def __init__(self, func_module, so_module):
+        """
+        Args:
+            func_module (module or None): Python module to serve as the fallback implementation source.
+                May be ``None`` if no Python fallback is available.
+            so_module (module): Compiled shared-object module that provides
+                optimized kernels accessible only in ``PYNATIVE_MODE``.
+        """
+        self.func_module = func_module
+        self.so_module = so_module
+
+    def __getattr__(self, name: str):
+        """
+        Intercepts every attribute lookup and resolves it against the wrapped
+        modules according to the documented precedence rules.
+
+        Args:
+            name (str): Name of the custom operation being requested.
+
+        Returns:
+            Any: The requested callable or attribute from either ``func_module`` or ``so_module``.
+
+        Raises:
+            AttributeError: If the attribute is not found in any applicable module or
+                is incompatible with the current execution mode.
+        """
+        if self.func_module is not None and hasattr(self.func_module, name):
+            return getattr(self.func_module, name)
+        if context.get_context("mode") == ms.PYNATIVE_MODE:
+            if hasattr(self.so_module, name):
+                return getattr(self.so_module, name)
+            raise AttributeError(
+                f"Custom op '{name}' is neither in func_module nor in so_module."
+            )
+
+        raise AttributeError(
+            f"Custom op '{name}' does not support GRAPH mode. "
+            f"Please register it for GRAPH mode or switch to PYNATIVE mode."
+        )
+
+
 class CustomOpBuilder:
     r"""
     CustomOpBuilder is used to initialize and configure custom operators for MindSpore.
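The precedence implemented by `_MultiSoProxy.__getattr__` is: Python fallback module first (in any mode), then the compiled module (PyNative only), then `AttributeError`. A minimal sketch with stand-in namespaces, assuming the private class defined in the hunk above is in scope:

    from types import SimpleNamespace

    func_module = SimpleNamespace(my_op=lambda x: x + 1)        # Python fallback
    so_module = SimpleNamespace(my_op=lambda x: x + 2,          # compiled module
                                so_only_op=lambda x: x * 2)

    proxy = _MultiSoProxy(func_module, so_module)
    proxy.my_op(1)       # -> 2: the Python fallback always wins
    proxy.so_only_op(3)  # -> 6 in PYNATIVE_MODE; AttributeError in GRAPH mode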
@@ -1199,10 +1251,11 @@ class CustomOpBuilder:

     Args:
         name (str): The unique name of the custom operator module, used to identify the operator.
-        sources (Union[str, list[str]]): The source file(s) of the custom operator. It can be a single
-            file path or a list of file paths.
+        sources (Union[list[str], tuple[str], str]): The source file(s) of the custom operator. It can be a single
+            file path or a list of file paths.
         backend (str, optional): The target backend for the operator, such as "CPU" or "Ascend". Default: ``None``.
-        include_paths (list[str], optional): Additionally included paths needed during compilation. Default: ``None``.
+        include_paths (Union[list[str], tuple[str], str], optional): Additionally included paths needed during
+            compilation. Default: ``None``.
         cflags (str, optional): Extra C++ compiler flags to be used during compilation. Default: ``None``.
         ldflags (str, optional): Extra linker flags to be used during linking. Default: ``None``.
         kwargs (dict, optional): Additional keyword arguments for future extensions or specific custom requirements.
@@ -1216,6 +1269,17 @@ class CustomOpBuilder:
         - enable_atb (bool, optional): Whether to call ATB (Ascend Transformer Boost) operator. If set to ``True``,
           the `backend` must be ``Ascend`` or left empty. Default: ``False``.

+        - enable_asdsip (bool, optional): Whether to call ASDSIP (Ascend SiP Boost) operator. If set to ``True``,
+          the `backend` must be ``Ascend`` or left empty. Default: ``False``.
+
+        - op_def (Union[list[str], tuple[str], str], optional): Path(s) to the operator definition
+          file(s) (YAML format). When using custom operators in graph mode, this parameter is mandatory.
+          It can be a single file path string or a list of file path strings. Default: ``None``.
+
+        - op_doc (Union[list[str], tuple[str], str], optional): Path(s) to the operator documentation
+          file(s) (YAML format). This parameter is optional and used to provide additional documentation
+          for the operator. It can be a single file path string or a list of file path strings. Default: ``None``.
+
     .. note::
         - If the `backend` argument is provided, additional default flags will be automatically added to
           the compilation and linking steps to support the operator's target backend. The default options
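Putting the new keyword arguments together, a graph-mode-capable builder call might look like the sketch below; the file names are placeholders, not values taken from this diff:

    from mindspore.ops import CustomOpBuilder

    builder = CustomOpBuilder(
        name="my_custom_ops",
        sources=["my_op.cpp"],        # str, list or tuple accepted after 2.7.1
        backend="Ascend",
        op_def="my_op.yaml",          # per the new docs, mandatory in graph mode
        op_doc="my_op_doc.yaml",      # optional documentation YAML
    )
    mod = builder.load()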
@@ -1238,20 +1302,20 @@ class CustomOpBuilder:
     _loaded_ops = {}

     def __init__(self, name, sources, backend=None, include_paths=None, cflags=None, ldflags=None, **kwargs):
-        self.name = name
-        self.source = sources
-        self.backend = backend
-        self.include_paths = include_paths
-        self.cflags = cflags
-        self.ldflags = ldflags
-        self.build_dir = kwargs.get("build_dir")
-        self.enable_atb = kwargs.get("enable_atb", False)
+        self._check_and_get_args(name, sources, backend, include_paths, cflags, ldflags, **kwargs)
+
         self._ms_path = os.path.dirname(os.path.abspath(ms.__file__))
+        self.auto_generate = self.name + "_auto_generate"
         if self.enable_atb:
             if backend is not None and backend != "Ascend":
                 raise ValueError("For 'CustomOpBuilder', when 'enable_atb' is set to True, the 'backend' must be "
                                  f"'Ascend' (or left implicit), but got '{backend}'")
             self.backend = "Ascend"
+        if self.enable_asdsip:
+            if backend is not None and backend != "Ascend":
+                raise ValueError("For 'CustomOpBuilder', when 'enable_asdsip' is set to True, the 'backend' must be "
+                                 f"'Ascend' (or left implicit), but got '{backend}'")
+            self.backend = "Ascend"
         if self.backend == "Ascend":
             ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
             if not ascend_opp_path:
@@ -1263,6 +1327,115 @@ class CustomOpBuilder:
         if not self.atb_home_path:
             raise ValueError("Environment variable 'ATB_HOME_PATH' must be set when 'enable_atb' is True.")

+    def _check_and_get_args(self, name, sources, backend=None, include_paths=None,
+                            cflags=None, ldflags=None, **kwargs):
+        """
+        Validate and normalize all arguments to meet custom-op build requirements.
+        """
+
+        def _check_str_or_list_str(key, val):
+            if val is None:
+                return val
+            if isinstance(val, str):
+                val = [val]
+            val = validator.check_value_type(key, val, [list, tuple])
+            val = list(val)
+            validator.check_element_type_of_iterable(key, val, [str])
+            return val
+
+        self.name = validator.check_value_type("name", name, [str])
+        self.source = _check_str_or_list_str("sources", sources)
+        self.backend = validator.check_value_type("backend", backend, [str, type(None)])
+        if self.backend is not None and self.backend not in {"CPU", "Ascend"}:
+            raise ValueError(
+                f"For 'backend', only 'CPU' or 'Ascend' are allowed, but got '{self.backend}'.")
+
+        self.include_paths = _check_str_or_list_str("include_paths", include_paths)
+
+        self.cflags = validator.check_value_type("cflags", cflags, [str, type(None)])
+        self.ldflags = validator.check_value_type("ldflags", ldflags, [str, type(None)])
+
+        self.build_dir = validator.check_value_type("build_dir",
+                                                    kwargs.get("build_dir"),
+                                                    [str, type(None)])
+
+        self.debug_mode = validator.check_bool(kwargs.get("debug_mode", False), "debug_mode")
+        self.enable_asdsip = validator.check_bool(kwargs.get("enable_asdsip", False), "enable_asdsip")
+        self.yaml = _check_str_or_list_str("op_def", kwargs.get("op_def"))
+        self.doc = _check_str_or_list_str("op_doc", kwargs.get("op_doc"))
+
+        self.enable_atb = validator.check_bool(kwargs.get("enable_atb", False))
+
+    def _generate_custom_op_def(self, module: str, input_path: str, doc_path: str, output_path: str) -> None:
+        """Call gen_custom_ops.py to generate custom operator definition"""
+        file_path = os.path.join(self._ms_path, "ops_generate/gen_custom_ops.py")
+        cmd = [
+            sys.executable,
+            file_path,
+            "-i", input_path,
+            "-o", output_path,
+            "-m", module,
+            "-d", doc_path
+        ]
+
+        try:
+            subprocess.run(
+                cmd,
+                check=True,
+                text=True,
+                capture_output=True
+            )
+        except subprocess.CalledProcessError as exc:
+            raise RuntimeError(
+                f"gen_custom_op.py failed with exit code {exc.returncode}.\n"
+                f"stdout: {exc.stdout}\n"
+                f"stderr: {exc.stderr}"
+            ) from None
+
+    def _get_op_def(self):
+        """
+        Generate C++ operator-definition source files from one or more YAML specification files.
+        """
+        if self.yaml is None:
+            return []
+
+        if self.doc is None:
+            logger.info("Missing required 'doc': no YAML document was provided.")
+
+        build_path = self._get_build_directory()
+        yaml_path = os.path.join(build_path, "yaml")
+        op_def_path = os.path.join(build_path, self.auto_generate)
+        if os.path.exists(op_def_path):
+            shutil.rmtree(op_def_path)
+        os.makedirs(op_def_path, exist_ok=True)
+
+        def copy_files(yaml_files, dest_path):
+            if os.path.exists(dest_path):
+                shutil.rmtree(dest_path)
+            os.makedirs(dest_path, exist_ok=True)
+            for file_path in yaml_files:
+                if not os.path.isfile(file_path):
+                    raise FileNotFoundError(f"File not found: {file_path}")
+
+                filename = os.path.basename(file_path)
+                file_ext = os.path.splitext(filename)[1].lower()
+                if file_ext not in ('.yaml', '.yml'):
+                    raise ValueError(f"Invalid file extension: {file_ext} for {filename}")
+
+                _dest_path = os.path.join(dest_path, filename)
+                shutil.copy2(file_path, _dest_path)
+
+        yaml_files = [self.yaml] if isinstance(self.yaml, str) else self.yaml
+        copy_files(yaml_files, yaml_path)
+        doc_path = ""
+        if self.doc is not None:
+            doc_path = os.path.join(build_path, "doc")
+            doc_files = [self.doc] if isinstance(self.doc, str) else self.doc
+            copy_files(doc_files, doc_path)
+
+        self._generate_custom_op_def(self.name, yaml_path, doc_path, op_def_path)
+        return [os.path.join(op_def_path, "gen_custom_ops_def.cc")]
+
     def get_sources(self):
         """
         Get the source files for the custom operator.
@@ -1270,7 +1443,8 @@ class CustomOpBuilder:
         Returns:
             str or list[str], The source file(s) for the operator.
         """
-        return self.source
+        self.source = [self.source] if isinstance(self.source, str) else self.source
+        return self.source + self._get_op_def()

     def get_include_paths(self):
         """
@@ -1297,6 +1471,7 @@ class CustomOpBuilder:
         """include paths for inner module interface."""
         ms_inner_path = os.path.join(self._ms_path, "include", "mindspore")
         include_list = []
+        include_list.append(os.path.join(ms_inner_path, "include"))
         include_list.append(os.path.join(ms_inner_path, "core", "include"))
         include_list.append(os.path.join(ms_inner_path, "core", "mindrt", "include"))
         include_list.append(os.path.join(ms_inner_path, "core", "mindrt"))
@@ -1316,10 +1491,16 @@ class CustomOpBuilder:
         """
         flags = [f'-DMS_EXTENSION_NAME={self.name}', '-D_GLIBCXX_USE_CXX11_ABI=0', '-DENABLE_FAST_HASH_TABLE=1']
         flags += ['-std=c++17', '-fstack-protector-all', '-fPIC', '-pie']
+        if self.debug_mode:
+            flags.append('-g')
+        else:
+            flags.append('-O2')
         if self.backend == "Ascend":
             flags.append('-DCUSTOM_ASCEND_OP')
         if self.enable_atb:
             flags.append('-DCUSTOM_ENABLE_ATB')
+        if self.enable_asdsip:
+            flags.append('-DCUSTOM_ENABLE_ASDSIP')
         if self.cflags is not None:
             flags.append(self.cflags)
         return flags
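A quick check of the new `debug_mode` switch, assuming these flag lists are exposed through the builder's public flag getters (`get_cflags`/`get_ldflags`, names not shown in this diff) and that a CPU build needs no Ascend environment:

    from mindspore.ops import CustomOpBuilder

    dbg = CustomOpBuilder("my_op_dbg", "my_op.cpp", backend="CPU", debug_mode=True)
    assert '-g' in dbg.get_cflags()        # debug build: symbols, no -O2
    assert '-s' not in dbg.get_ldflags()   # not stripped (see the next hunk)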
@@ -1332,24 +1513,31 @@ class CustomOpBuilder:
             list[str], A list of linker flags.
         """
         flags = ['-shared']
-        flags += ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath', '-s']
+        flags += ['-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath']
+        if not self.debug_mode:
+            flags.append('-s')  # strip
         flags += [
             f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib'))}",
             '-lmindspore_core',
             '-lmindspore_ms_backend',
             '-lmindspore_pynative',
-            '-lmindspore_pyboost',
+            '-lmindspore_pyboost'
         ]
         if self.backend == "Ascend":
-            flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin'))}")
             flags.append(f"-L{os.path.abspath(os.path.join(self.ascend_cann_path, 'lib64'))}")
             flags.append('-lascendcl')
+            plugin_path = os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin'))
+            flags.append(f"-L{plugin_path}")
+            flags.append(f"-L{os.path.join(plugin_path, 'ascend')}")
             flags.append('-l:libmindspore_ascend.so.2')
+            flags.append('-lmindspore_extension_ascend_aclnn')
         if self.enable_atb:
-            flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin', 'ascend'))}")
             flags.append('-lmindspore_extension_ascend_atb')
             flags.append(f"-L{os.path.abspath(os.path.join(self.atb_home_path, 'lib'))}")
             flags.append('-latb')
+        if self.enable_asdsip:
+            flags.append(f"-L{os.path.abspath(os.path.join(self._ms_path, 'lib', 'plugin', 'ascend'))}")
+            flags.append('-lmindspore_extension_ascend_asdsip')
         if self.ldflags is not None:
             flags.append(self.ldflags)
         return flags
@@ -1380,15 +1568,42 @@ class CustomOpBuilder:
         """
         if self.name in CustomOpBuilder._loaded_ops:
             return CustomOpBuilder._loaded_ops[self.name]
+
         module_path = self.build()
-        mod = self._import_module(module_path)
+        so_module = CustomOpBuilder._import_module(module_path)
+        func_module = None
+        if self.yaml is not None:
+            module_path = os.path.join(self.build_dir, self.auto_generate, "gen_ops_def.py")
+            sys.path.append(os.path.join(self.build_dir, self.auto_generate))
+            sys.path.append(os.path.join(self.build_dir))
+            func_module = self._import_module(module_path, True)
+        mod = _MultiSoProxy(func_module, so_module)
         CustomOpBuilder._loaded_ops[self.name] = mod
         return mod

-    def _import_module(self, module_path):
+    @staticmethod
+    def _import_module(module_path, is_yaml_build=False):
         """Import module from library."""
-        spec = importlib.util.spec_from_file_location(self.name, module_path)
+        module_path = os.path.abspath(module_path)
+        module_dir = os.path.dirname(module_path)
+        module_name = os.path.splitext(os.path.basename(module_path))[0]
+
+        if is_yaml_build:
+            package_name = os.path.basename(module_dir)
+            if module_dir not in sys.path:
+                sys.path.append(module_dir)
+
+            if package_name not in sys.modules:
+                pkg_spec = importlib.machinery.ModuleSpec(package_name, None, is_package=True)
+                pkg = importlib.util.module_from_spec(pkg_spec)
+                pkg.__path__ = [module_dir]
+                sys.modules[package_name] = pkg
+
+            module_name = f"{package_name}.{module_name}"
+
+        spec = importlib.util.spec_from_file_location(module_name, module_path)
         module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
         spec.loader.exec_module(module)
         return module

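After this change, `load()` no longer hands back the compiled module directly: it wraps it in the `_MultiSoProxy` defined earlier, and when `op_def` YAML was supplied it also imports the generated `gen_ops_def.py` as the Python-side member. A hedged usage sketch with placeholder names:

    from mindspore.ops import CustomOpBuilder

    mod = CustomOpBuilder("my_ops", "my_ops.cpp", backend="CPU").load()
    # 'mod' is a _MultiSoProxy; attribute lookup tries the generated Python
    # module first (present only if op_def was given), then the built .so
    # (PYNATIVE mode only).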