mindspore 2.7.0rc1-cp311-cp311-win_amd64.whl → 2.7.1-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/distributed/flatten_grad_buffer.py
ADDED
@@ -0,0 +1,295 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Param and grad buffer, bucket implemenatrion. """
+from __future__ import absolute_import
+
+__all__ = ["Bucket", "FlattenGradBuffer"]
+
+from enum import Enum
+import numpy as np
+from mindspore import mint, Tensor
+from mindspore.common.initializer import Zero
+from mindspore.communication.management import get_group_size
+import mindspore.communication.comm_func as comm_func
+
+
+class BufferType(Enum):
+    PARAM = 0
+    GRAD = 1
+
+
+MEM_ALIGN_SIZE = 512
+ALIGN_BYTES = 32
+MIN_BUCKET_SIZE = int(1 * 1024 * 1024)
+DEFAULT_BUCKET_SIZE = int(25 * 1024 * 1024)
+
+
+class Bucket:
+    """
+    Bucket to track a subset of parameters and gradients in the buffer. Bucket records the parameters
+    whose gradient has already been computed. It also provide functionality to synchronize gradients among
+    data parallel group when all parameters' graidents have been computed.
+
+    Args:
+        average_in_collective (bool): Scaling grads before/after AllReduce, True: scaling after AllReduce.
+        params (List(Parameters)): Parameters belongs to this bucket.
+        grad_data (Tensor): A section of buffers' gradient data, coressponding to parameters in this bucket.
+        offset (int): Start index in the buffer.
+        numel_unpadded (int): Number of unpadded elements in bucket.
+        data_parallel_group (str): Data parallel group name.
+        data_parallel_world_size (int): Data parallel group size.
+        gradient_scaling_factor (float): Work with average_in_collective, it is 1.0 when average_in_collective
+            true else 1.0/dp
+    """
+
+    def __init__(self, average_in_collective, params, grad_data, offset, numel_unpadded, data_parallel_group,
+                 data_parallel_world_size, gradient_scaling_factor):
+        self.average_in_collective = average_in_collective
+        self.params_list = params
+        self.params = set(params)
+        self.params_grad_ready = set()
+        self.grad_data = grad_data
+        self.grad_data_numel = self.grad_data.numel()
+        self.offset = offset
+        self.numel_unpadded = numel_unpadded
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = data_parallel_world_size
+        self.gradient_scaling_factor = gradient_scaling_factor
+
+        if self.data_parallel_world_size > 1:
+            self.grad_reducer = comm_func.all_reduce
+
+        self.reset()
+
+    def inplace_reduce_dp(self, src):
+        """conduct all-reduce/reduce-scatter on src tensor and inplace update result into target."""
+        self.communication_result, self.communication_handle = self.grad_reducer(
+            src, "sum", self.data_parallel_group, async_op=True
+        )
+
+    def reset(self):
+        """reset bucket for the next iteration."""
+        self.params_grad_ready = set()
+        self.is_reduce_issued = False
+        self.communication_handle = None
+        self.communication_result = None
+
+    def issue_grad_reduce(self):
+        """issue grad reduce for the local grad data view."""
+        if self.is_reduce_issued:
+            raise RuntimeError("The bucket reduce is already issued")
+
+        if self.gradient_scaling_factor != 1.0:
+            self.grad_data.copy_(mint.mul(self.grad_data, self.gradient_scaling_factor))
+
+        if self.data_parallel_world_size > 1:
+            self.inplace_reduce_dp(self.grad_data)
+
+        self.is_reduce_issued = True
+
+    def final_grad_reduce(self):
+        """finalize grad reduce for the local grad data view."""
+        start_idx = 0
+        end_idx = self.grad_data_numel
+        target = self.grad_data[start_idx:end_idx]
+
+        if not self.is_reduce_issued:
+            raise RuntimeError(
+                f"The bucket reduce has not been issued "
+                f"with only {len(self.params_grad_ready)}/{len(self.params)} params ready"
+            )
+
+        if self.data_parallel_world_size > 1:
+            self.communication_handle.wait()
+            target.copy_(self.communication_result)
+            self.communication_result = None
+            if self.average_in_collective:
+                target.copy_(mint.div(target, self.data_parallel_world_size))
+
+    def register_grad_ready(self, param):
+        """register grad ready and issue bucket grad reduce when the bucket is ready."""
+        if param not in self.params:
+            raise ValueError("The param to be registered is not in the bucket")
+
+        if param in self.params_grad_ready:
+            raise ValueError(f"The param {param} is already registered")
+
+        self.params_grad_ready.add(param)
+        if len(self.params_grad_ready) == len(self.params):
+            self.issue_grad_reduce()
+            return True
+
+        return False
+
+    def __repr__(self):
+        return f"Bucket (offset={self.offset}, param_lens={len(self.params)})"
+
+
+class FlattenGradBuffer:
+    """
+    Allocate contiguous memory buffer for given parameters and corresponding gradients. Breaking
+    up parameters and gradients buffer into small buckets, which is the unit for all-reduce/reduce-scatter
+    communication during back-propagation.
+
+    Args:
+        average_in_collective (bool): Scaling grads before/after AllReduce, True: scaling after AllReduce.
+        param_dtype (mindspore.dtype): The parameters' datatype.
+        grad_dtype (mindspore.dtype): The gradients' datatype.
+        params (List(Parameters)): Parameters belongs to this buffer.
+        data_parallel_group (str): Data parallel group name.
+        bucket_size (int): Bucket size threshold used to partition bucekts.
+        gradient_scaling_factor (float):
+    """
+
+    def __init__(self, average_in_collective, param_dtype, grad_dtype, params, data_parallel_group,
+                 bucket_size, gradient_scaling_factor, ddp_handle):
+        super(FlattenGradBuffer, self).__init__()
+        self.param_dtype = param_dtype
+        self.grad_dtype = grad_dtype
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = get_group_size(group=self.data_parallel_group)
+        self.gradient_scaling_factor = gradient_scaling_factor
+        self.average_in_collective = average_in_collective
+
+        self.buckets = []
+        self.param_index_map = {}
+        self.param_to_bucket = {}
+        self.sync_enabled = True
+        self.issued = 0
+        self.ddp_handle = ddp_handle
+
+        buckets_metadata = self.calc_partition_metadata(bucket_size, params)
+        self.instantiate_buckets(buckets_metadata, params)
+
+    def calc_partition_metadata(self, bucket_size, params):
+        """calc bucket partition metadata"""
+        # helper func
+        def _need_new_bucket(bucket_numel, bucket_id):
+            target_bucket_size = bucket_size
+            if bucket_id == 0 and bucket_size == DEFAULT_BUCKET_SIZE:
+                target_bucket_size = MIN_BUCKET_SIZE
+            return (
+                bucket_size is not None
+                and bucket_numel != 0
+                and bucket_numel >= target_bucket_size
+            )
+
+        def _build_bucket():
+            nonlocal buckets_metadata, bucket_start_index, bucket_params, bucket_id
+            bucket_end_index = data_start_index
+            buckets_metadata.append(
+                (bucket_start_index, bucket_end_index, bucket_params)
+            )
+            bucket_start_index = bucket_end_index
+            bucket_id = bucket_id + 1
+            bucket_params = []
+
+        param_data_list = []
+        buckets_metadata = []
+        data_start_index = 0
+        data_end_index = 0
+        bucket_id = 0
+        bucket_start_index = 0
+        bucket_params = []
+        for param in params[::]: # traverse from the beginning
+            last_bucket_numel = data_start_index - bucket_start_index
+            if _need_new_bucket(last_bucket_numel, bucket_id):
+                _build_bucket()
+            data_end_index = data_start_index + param.numel()
+            bucket_params.append(param)
+            param_data_list.append(param)
+            self.param_index_map[param] = (data_start_index, data_end_index, bucket_id)
+            data_start_index = data_end_index
+
+        # add bucket for the last few params which do not reach the bucket_size threshold
+        if data_start_index - bucket_start_index > 0:
+            bucket_end_index = data_start_index
+            buckets_metadata.append(
+                (bucket_start_index, bucket_end_index, bucket_params)
+            )
+            data_start_index = bucket_end_index
+
+        # allocate contiguous memory for parameters and gradients
+        self.numel = data_start_index
+        self.grad_data = Tensor(shape=(self.numel), dtype=self.grad_dtype, init=Zero())
+        self.grad_data.init_data()
+        self.numel_unpadded = 0
+        return buckets_metadata
+
+    def instantiate_buckets(self, buckets_metadata, params):
+        """build bucket instance according to partition metadata"""
+        for bucket_start_index, bucket_end_index, bucket_params in buckets_metadata:
+            local_grad_data = self.grad_data[bucket_start_index:bucket_end_index]
+            self.numel_unpadded += bucket_end_index - bucket_start_index
+            bucket = Bucket(
+                average_in_collective=self.average_in_collective,
+                params=bucket_params,
+                grad_data=local_grad_data,
+                offset=bucket_start_index,
+                numel_unpadded=bucket_end_index - bucket_start_index,
+                data_parallel_group=self.data_parallel_group,
+                data_parallel_world_size=self.data_parallel_world_size,
+                gradient_scaling_factor=self.gradient_scaling_factor,
+            )
+            self.buckets.append(bucket)
+            for param in bucket_params:
+                self.param_to_bucket[param] = bucket
+
+        for param in params:
+            data_start_index, _, _ = self.param_index_map[param]
+            param.grad = self._get_buffer_slice(
+                param.shape, data_start_index, BufferType.GRAD
+            )
+
+    def _get_buffer_slice(self, shape, start_index, buffer_type):
+        """get the buffer view with the same shape"""
+        end_index = start_index + int(np.prod(shape))
+        if start_index < 0 or end_index > self.numel:
+            raise ValueError("index out of range")
+        if buffer_type == BufferType.GRAD:
+            buffer_tensor = self.grad_data[start_index:end_index]
+        else:
+            raise TypeError("Invalid buffer type for _get_buffer_slice.")
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+    def reset(self):
+        """reset buffer for the next iteration."""
+        self.grad_data.zero_()
+        for bucket in self.buckets:
+            bucket.reset()
+        self.sync_enabled = True
+
+    def final_grad_reduce(self):
+        """finalize grad reduce for each bucket"""
+        for bucket in self.buckets:
+            bucket.final_grad_reduce()
+
+    def register_grad_ready(self, param):
+        """register ready grad in its buckets"""
+        if self.sync_enabled:
+            bucket = self.param_to_bucket[param]
+            if bucket.register_grad_ready(param):
+                self.issued += 1
+                if self.issued == len(self.buckets):
+                    self.ddp_handle.buffer_issued += 1
+                    if self.ddp_handle.buffer_issued == len(self.ddp_handle.buffers):
+                        self.ddp_handle.final_grad_reduce()
+
+    def __repr__(self):
+        param_index_with_name = {
+            param.name: index for (param, index) in self.param_index_map.items()
+        }
+        return f"Buffer has buckets: \n {self.buckets} \n and param_index_map: \n {param_index_with_name}"
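The new FlattenGradBuffer packs all gradients into one flat tensor and cuts it into buckets; each bucket issues an asynchronous all_reduce as soon as every one of its parameters has registered a ready gradient, and final_grad_reduce() waits on the handles and applies the optional averaging. As a rough illustration of just the bucket-cutting rule in calc_partition_metadata, the standalone sketch below models parameters as plain integer element counts; the function name partition is ours, not part of the package, and no MindSpore objects are involved.

```python
# Standalone sketch (not MindSpore API): the bucket-cutting rule from
# calc_partition_metadata, modelled with plain integer element counts.
MIN_BUCKET_SIZE = 1 * 1024 * 1024
DEFAULT_BUCKET_SIZE = 25 * 1024 * 1024


def partition(param_numels, bucket_size=DEFAULT_BUCKET_SIZE):
    """Return (start, end, member_indices) spans over a flat grad buffer."""
    buckets, members = [], []
    bucket_start, offset = 0, 0
    for idx, numel in enumerate(param_numels):
        # The first bucket is cut early (MIN_BUCKET_SIZE) when the default
        # threshold is in use, so communication starts to overlap sooner.
        threshold = MIN_BUCKET_SIZE if (not buckets and bucket_size == DEFAULT_BUCKET_SIZE) else bucket_size
        if offset - bucket_start >= threshold:
            buckets.append((bucket_start, offset, members))
            bucket_start, members = offset, []
        members.append(idx)
        offset += numel
    if offset > bucket_start:  # leftover params that never hit the threshold
        buckets.append((bucket_start, offset, members))
    return buckets


# Two big params end up in their own buckets, the small tail shares one.
print(partition([2_000_000, 30_000_000, 500_000, 4_096, 1_024]))
```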
mindspore/parallel/function/reshard_func.py
CHANGED
@@ -42,11 +42,12 @@ def reshard(tensor, layout):
     can check :class:`mindspore.parallel.Layout` for reference.

     Note:
-
-
-
-
-
+        In the Graph mode, this function can set the sharding propagation strategy of a tensor.
+        For those tensor do not manually be set, their strategies are decided by the sharding
+        strategy propagation algorithm automatically.
+
+    .. warning::
+        The method is currently not supported in PyNative mode.

     Args:
         tensor (Tensor): The tensor to be set the sharding strategy.
mindspore/parallel/nn/parallel_cell_wrapper.py
CHANGED
@@ -28,7 +28,8 @@ from mindspore import log as logger

 class PipelineCell(Cell):
     """
-    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training
+    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training,
+    and specify the segment info.

     Note:
         micro_size must be greater or equal to pipeline stages.
@@ -37,6 +38,8 @@ class PipelineCell(Cell):
         network (Cell): The target network to wrap.
         micro_size (int): MicroBatch size.
         stage_config (dict, optional): The stage configuration for each cell's execution in pipeline parallel.
+        segment_config (dict, optional): The segment configuration for each cell's execution in pipeline parallel.
+            Default ``None``.

     Supported Platforms:
         ``Ascend``
@@ -48,7 +51,7 @@ class PipelineCell(Cell):
         >>> net = LeNet5()
         >>> net = nn.PipelineCell(net, 4, stage_config={"cell_name_0": 0, "cell_name_1": 1})
     """
-    def __init__(self, network, micro_size, stage_config=None):
+    def __init__(self, network, micro_size, stage_config=None, segment_config=None):
         super(PipelineCell, self).__init__(auto_prefix=False)
         self.network = network
         self.micro_inputs = nn.CellList()
@@ -104,6 +107,37 @@ class PipelineCell(Cell):
                    logger.warning(cell_name)
                raise KeyError("For 'PipelineCell', the argument 'stage_config' : {} is not "
                               "found in 'network' : {}".format(config_dict, network))
+        if segment_config is None:
+            return
+        self._config_segment(segment_config)
+
+
+    def _config_segment(self, segment_config):
+        """
+        Config segment num for cell.
+        """
+        config_dict = segment_config.copy()
+
+        for cell_name, cell in self.network.cells_and_names():
+            if cell_name in segment_config:
+                setattr(cell, "pipeline_segment", segment_config[cell_name])
+                del config_dict[cell_name]
+        if str(self.network) in segment_config:
+            setattr(self.network, "pipeline_segment", segment_config[str(self.network)])
+            del config_dict[str(self.network)]
+        # if there are any config elements left, print them
+        if config_dict:
+            for config_cell_name, config_segment_num in config_dict.items():
+                logger.error("pipeline_cell segment_config set pipeline_segment fail!")
+                logger.warning("config cell name:" + str(config_cell_name) +
+                               " config segment num:" + str(config_segment_num))
+            logger.warning("network:" + str(self.network))
+            logger.warning("cell name available:")
+            for cell_name, _ in self.network.cells_and_names():
+                logger.warning(cell_name)
+            raise KeyError("For 'PipelineCell', the argument 'segment_config' : {} is not "
+                           "found in 'network' : {}".format(config_dict, self.network))
+

     def construct(self, *args, **kwargs):
         ret = None
@@ -119,7 +153,8 @@ class PipelineCell(Cell):

 class Pipeline(PipelineCell):
     """
-    Specify the number of micro_batch for pipeline parallelism and the division rules for stage
+    Specify the number of micro_batch for pipeline parallelism and the division rules for stage,
+    and specify the segment info.

     Note:
         micro_size must be greater or equal to pipeline stages.
@@ -128,6 +163,8 @@ class Pipeline(PipelineCell):
         network (Cell): The target network to wrap.
         micro_size (int): MicroBatch size.
         stage_config (dict, optional): Stage configuration for cell's execution in pipeline parallel. Default ``None``.
+        segment_config (dict, optional): The segment configuration for each cell's execution in pipeline parallel.
+            Default ``None``.

     Raises:
         TypeError: The type of `net` is not cell.
mindspore/parallel/nn/parallel_grad_reducer.py
CHANGED
@@ -17,7 +17,6 @@ from __future__ import absolute_import

 __all__ = ['PipelineGradReducer']

-from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import functional as F, composite as C, operations as P
 import mindspore.common.dtype as mstype
@@ -140,7 +139,6 @@ class PipelineGradReducer(Cell):
     """
     def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
         super(PipelineGradReducer, self).__init__(auto_prefix=False)
-        self._check_mode()
         self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
         self.grad_reducer = Identity()
         self.degree = Tensor(1, mstype.float32)
@@ -162,9 +160,3 @@ class PipelineGradReducer(Cell):
         accu_grads = self.grad_reducer(self.accu_grads)
         new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
         return new_grads
-
-    def _check_mode(self):
-        """check parallel mode"""
-        mode = context.get_context('mode')
-        if mode != context.GRAPH_MODE:
-            raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}")
mindspore/parallel/shard.py
CHANGED
@@ -253,13 +253,6 @@ class Shard(Shard_):
                            "will be overwritten as False.")
             ms.set_algo_parameters(fully_use_devices=False)

-        if ms.context.get_auto_parallel_context("full_batch_is_set") is False and \
-                ms.context.get_context("mode") == ms.context.PYNATIVE_MODE:
-            logger.warning("When calling the shard interface, "
-                           "'dataset_strategy' or 'full_batch' is not manually set by the user, "
-                           "and the 'dataset_strategy' will be set to 'full_batch'.")
-            ms.context.set_auto_parallel_context(dataset_strategy="full_batch")
-
         if self._is_attrs_has_been_set(fn, in_strategy, out_strategy, device, level):
             return self.shard_fn
         shard_ = Shard()
@@ -394,11 +387,10 @@ class Shard(Shard_):
                             f"The tuple strategy for each dimension should be tuple(int).")


-def shard(fn, in_strategy, out_strategy=None, parameter_plan=None
+def shard(fn, in_strategy, out_strategy=None, parameter_plan=None):
     """
     Specify the input and output slicing strategy for a Cell or function.
-    In
-    execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+    In Graph mode, use this method to specify distribution strategy for a Cell,
     strategy for others will be set by sharding propagation.
     in_strategy and out_strategy define the input and output layout respectively.
     in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -410,7 +402,9 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         - It is valid only in semi auto parallel or auto parallel mode.
           In other parallel modes, strategies set here will be ignored.
         - If the input contain Parameter, its strategy should be set in `in_strategy`.
-
+
+    .. warning::
+        The method is currently not supported in PyNative mode.

     Args:
         fn (Union[Cell, Function]): Function to be executed in parallel.
@@ -432,19 +426,12 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
            has been set, the parameter setting will be ignored. Supported
            only when `fn` is a Cell with parameters.
            Default: ``None`` .
-        device (str, optional): Select a certain `device` target. It is not in use right now.
-            Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
-        level (int, optional): Option for parallel strategy infer algorithm, namely the object function,
-            maximize computation
-            over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
-            use right now. Support [0, 1, 2]. Default: ``0`` .

     Returns:
         Function, return the function that will be executed under auto parallel process.

     Raises:
         AssertionError: If parallel mode is not "auto_parallel" nor "semi_auto_parallel".
-        AssertionError: If device_target it not "Ascend" or "GPU".
         TypeError: If `in_strategy` is not a tuple.
         TypeError: If `out_strategy` is not a tuple or None.
         TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
@@ -452,8 +439,6 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         TypeError: If `parameter_plan` is not a dict or None.
         TypeError: If any key in `parameter_plan` is not a str.
         TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.parallel.Layout).
-        TypeError: If `device` is not a str.
-        TypeError: If `level` is not an integer.

     Supported Platforms:
         ``Ascend``
@@ -556,4 +541,5 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
     if not isinstance(fn, (ms.nn.Cell)):
         logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
                        "otherwise, the result may be incorrect.")
-
+
+    return Shard()(fn, in_strategy, out_strategy, parameter_plan)
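With this change, shard() drops the unused device and level arguments and the PyNative-specific full_batch fallback; the call now takes only fn, in_strategy, out_strategy and parameter_plan. The hedged sketch below shows the trimmed call shape; the ms.parallel.shard alias is assumed from the module path in this diff (older releases exposed the same function as ms.shard), and actually executing it requires an auto-parallel job (e.g. launched with msrun) with sharding propagation enabled.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# Plain function with no Parameter inside, so the non-Cell warning in the
# diff above does not apply.
def matmul_relu(x, w):
    return ops.relu(ops.matmul(x, w))

ms.set_context(mode=ms.GRAPH_MODE)
# Trimmed signature after this change: no `device` / `level` arguments.
# Split the first input along dim 0 across 4 devices, replicate the second.
sharded_fn = ms.parallel.shard(matmul_relu, in_strategy=((4, 1), (1, 1)))
out = sharded_fn(Tensor(np.ones((8, 16), np.float32)),
                 Tensor(np.ones((16, 4), np.float32)))
```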