mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/parallel/nn/parallel_grad_reducer.py ADDED
@@ -0,0 +1,169 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""parallel serialization"""
+from __future__ import absolute_import
+
+from mindspore import context
+from mindspore.nn.cell import Cell
+from mindspore.ops import functional as F, composite as C, operations as P
+import mindspore.common.dtype as mstype
+from mindspore.common.sparse_tensor import Tensor
+from mindspore.common.api import jit
+from mindspore.common.parameter import Parameter
+from mindspore.nn.layer import Identity
+from mindspore.parallel._utils import _get_enable_parallel_optimizer
+
+__all__ = ['PipelineGradReducer']
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+    return new_grad
+
+
+@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+    new_grad = grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+    return new_grad
+
+
+class PipelineGradReducer(Cell):
+    """
+    Functional training scenarios for gradient statute and accumulation of pipeline parallel.
+
+    Args:
+        parameters (list): the parameters to be updated.
+        scale_sense (float, optional): the scale sense of the gradient. Default: 1.0.
+        opt_shard(bool, optional): if use parallel optimizer, set opt_shard True. Default: ``None``.
+
+    Raise:
+        RuntimeError: If the mode is not graph mode.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table Startup
+            <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
+            for more details.
+
+            This example should be run with multiple devices.
+
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import nn, ops, Tensor
+        >>> from mindspore.communication import init
+        >>>
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> ms.reset_auto_parallel_context()
+        >>> init()
+        >>> ms.set_seed(1)
+        >>>
+        >>> class Network(nn.Cell):
+        ...     def __init__(self, in_features, out_features, sens=1.0):
+        ...         super().__init__()
+        ...         self.layer1 = nn.Dense(in_features, 16)
+        ...         self.relu1 = nn.ReLU()
+        ...         self.layer2 = nn.Dense(16, 16)
+        ...         self.relu2 = nn.ReLU()
+        ...         self.layer3 = nn.Dense(16, out_features)
+        ...
+        ...     def construct(self, x):
+        ...         x = self.layer1(x)
+        ...         x = self.relu1(x)
+        ...         x = self.layer2(x)
+        ...         x = self.relu2(x)
+        ...         logits = self.layer3(x)
+        ...         return logits
+        >>>
+        >>> size, in_features, out_features = 16, 32, 10
+        >>> net = Network(in_features, out_features)
+        >>> net.layer1.pipeline_stage = 0
+        >>> net.relu1.pipeline_stage = 0
+        >>> net.layer2.pipeline_stage = 0
+        >>> net.relu2.pipeline_stage = 1
+        >>> net.layer3.pipeline_stage = 1
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
+        >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 2)
+        >>> net_with_loss.set_train()
+        >>> def forward_fn(inputs, target):
+        ...     loss = net_with_loss(inputs, target)
+        ...     return loss
+        >>>
+        >>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params())
+        >>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters)
+        >>>
+        >>> @ms.jit
+        >>> def train_one_step(inputs, target):
+        ...     loss, grads = grad_fn(inputs, target)
+        ...     grads = pp_grad_reducer(grads)
+        ...     optimizer(grads)
+        ...     return loss, grads
+        >>>
+        >>> parallel_net = AutoParallel(train_one_step, parallel_mode="semi_auto")
+        >>> parallel_net.pipeline(stages=2)
+        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
+        >>> label = Tensor(np.ones([size, out_features]).astype(np.float32))
+        >>> loss, _ = train_one_step(inputs, label)
+        >>> print(loss)
+        46.36721
+    """
+    def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
+        super(PipelineGradReducer, self).__init__(auto_prefix=False)
+        self._check_mode()
+        self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
+        self.grad_reducer = Identity()
+        self.degree = Tensor(1, mstype.float32)
+        self.scale_sense = Parameter(scale_sense, name='scale_sense')
+        self.hyper_map = C.HyperMap()
+        if opt_shard is None:
+            self.opt_shard = _get_enable_parallel_optimizer()
+        else:
+            self.opt_shard = opt_shard
+
+    @jit
+    def construct(self, grads):
+        new_grads = None
+        if self.opt_shard:
+            grads = self.grad_reducer(grads)
+            new_grads = self.hyper_map(F.partial(shard_grad_scale, self.scale_sense * self.degree),
+                                       grads, self.accu_grads)
+        else:
+            accu_grads = self.grad_reducer(self.accu_grads)
+            new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
+        return new_grads
+
+    def _check_mode(self):
+        """check parallel mode"""
+        mode = context.get_context('mode')
+        if mode != context.GRAPH_MODE:
+            raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}")
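
The two scaling branches registered above differ only in which tensor is divided by the loss scale before the accumulator is cleared. As a rough illustration of the arithmetic in tensor_grad_scale_pipeline (the non-optimizer-shard path), here is a standalone NumPy sketch with made-up values; the F.depend calls in the real code only order side effects and do not change the numbers.

import numpy as np

# Illustrative values only (not taken from the diff): a loss scale of 2.0 and a
# 2x2 gradient buffer accumulated over the pipeline micro-batches.
scale = np.float32(2.0)                        # stands in for self.scale_sense * self.degree
accu_grad = np.full((2, 2), 4.0, np.float32)   # accumulated gradient buffer

# tensor_grad_scale_pipeline: unscale the accumulated gradient ...
new_grad = accu_grad * (1.0 / scale)           # reciprocal(scale) in the diff
# ... then zero the accumulator for the next accumulation window.
accu_grad = np.zeros_like(accu_grad)

print(new_grad)         # [[2. 2.] [2. 2.]]
print(accu_grad.sum())  # 0.0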
mindspore/parallel/parameter_broadcast.py CHANGED
@@ -19,7 +19,10 @@ __all__ = ["parameter_broadcast"]
 
 import numpy as np
 import mindspore as ms
-from mindspore.communication import
+from mindspore.communication import create_group, get_group_size
+from mindspore.parallel._utils import _get_auto_parallel_net, _parallel_mode_map, _check_rank
+# disable pylint too broad Exception
+# pylint: disable=W0212
 
 
 def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
@@ -34,7 +37,8 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
         layout (Dict): Parameter layout dictionary. Come from
             :func:`mindspore.nn.Cell.parameter_layout_dict`
             or read from file(for example: "strategy.ckpt" saved by using the
-            `strategy_ckpt_config` parameter of
+            `strategy_ckpt_config` parameter of
+            :func:`mindspore.parallel.auto_parallel.AutoParallel.save_param_strategy_file` ).
             The key is param name, the value is the layout of this parameter.
         cur_rank (int, optional): current rank id. Default: ``0``.
         initial_rank (int, optional): Start rank id for each pipeline. Default: ``0``.
@@ -45,6 +49,9 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
         ValueError: Parameter name in `layout` can not be found in
             :func:`mindspore.nn.Cell.parameters_dict`.
 
+    Supported Platforms:
+        ``Ascend``
+
     Examples:
         >>> import os
         >>> import mindspore as ms
@@ -53,11 +60,11 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
         >>> from mindspore.communication import init
         >>> from mindspore.common.initializer import initializer
         >>> from mindspore.train import Model
-        >>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
         >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
+        >>> from mindspore.parallel.auto_parallel import AutoParallel
+        >>> from mindspore.parallel import parameter_broadcast
         >>> ms.set_context(mode=ms.GRAPH_MODE)
-        >>> ms.
-        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
+        >>> ms.runtime.set_memory(max_size="28GB")
         >>> init()
         >>> ms.set_seed(1)
         >>> class Network(nn.Cell):
@@ -90,7 +97,8 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
         >>> dataset = create_dataset()
         >>> optim = nn.SGD(net.trainable_params(), 1e-2)
         >>> loss = nn.CrossEntropyLoss()
-        >>>
+        >>> parallel_net = AutoParallel(net)
+        >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim)
         >>> model.train(1, dataset)
         >>> ms.save_checkpoint(net, "./simple.ckpt", False)
         >>> layout = model.train_network.parameter_layout_dict
@@ -104,17 +112,20 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
         ...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
         >>> model.train(1, dataset, callbacks=[LossCallBack()])
     """
-    if not layout:
+    if not layout or get_group_size() <= 1:
         return
     from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
     from mindspore.nn.wrap.cell_wrapper import AllreduceGraph
-    origin_parallel_mode =
-
-
-    if
-
-
-
+    origin_parallel_mode = ""
+    pipeline_stages = 1
+    parallel_net = _get_auto_parallel_net(net)
+    if type(parallel_net).__name__ == 'AutoParallel':
+        origin_parallel_mode = _parallel_mode_map(parallel_net._parallel_mode)
+        pipeline_stages = parallel_net._pipeline_stages
+    else:
+        origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
+        pipeline_stages = ms.get_auto_parallel_context("pipeline_stages")
+    _check_rank(cur_rank, initial_rank, pipeline_stages)
     param_redundancy = get_parameter_redundancy(layout, initial_rank)
     if not param_redundancy:
         return
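
The rewritten body above resolves the parallel mode and pipeline stage count from either the new AutoParallel wrapper cell or, failing that, the global auto-parallel context, before validating the ranks and computing parameter redundancy. A minimal sketch of that dispatch, using a hypothetical resolve_parallel_settings helper and a plain get_context callable instead of MindSpore internals such as _get_auto_parallel_net:

# Hypothetical helper; mirrors the branching added in parameter_broadcast above.
def resolve_parallel_settings(parallel_net, get_context):
    """Return (parallel_mode, pipeline_stages) for either wrapping style."""
    if type(parallel_net).__name__ == "AutoParallel":
        # New 2.6 style: settings live on the AutoParallel(cell) wrapper itself.
        return parallel_net._parallel_mode, parallel_net._pipeline_stages
    # Legacy style: settings come from the global auto-parallel context.
    return get_context("parallel_mode"), get_context("pipeline_stages")

Once both values are known, the real code checks cur_rank and initial_rank against the stage count (_check_rank) and only then derives the parameter redundancy map.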
mindspore/parallel/shard.py CHANGED
@@ -15,14 +15,77 @@
 """shard"""
 
 import copy
+import numpy as np
 import mindspore as ms
 from mindspore import log as logger
 from mindspore._c_expression import Shard_
 
 
+class _DistributedTensorInfo:
+    """
+    Describe the distributed information of a tensor.
+
+    Args:
+        distributed_info (Union[Layout, DeviceMesh]): The distributed information of a tensor.
+
+    Raises:
+        TypeError: If `distributed_info` is not a Layout type.
+
+    Examples:
+        >>> from mindspore import _DistributedTensorInfo, Layout
+        >>> layout = Layout((2, 2), ("dp", "mp"))
+        >>> src_layout = layout("dp", "mp")
+        >>> distributed_info = _DistributedTensorInfo(src_layout)
+        >>> print(distributed_info.sharding_strategy)
+        [2, 2]
+    """
+
+    def __init__(self, distributed_info):
+        if isinstance(distributed_info, Layout):
+            self._layout = distributed_info
+            self._distributed_info = distributed_info
+        else:
+            raise TypeError(
+                f"DistributedTensorInfo only supports Layout or DeviceMesh as input, but got {type(distributed_info)}")
+        self._sharding_strategy = None
+
+    @property
+    def layout(self):
+        """return layout of current tensor"""
+        return self._layout
+
+    @property
+    def distributed_info(self):
+        """return the distributed info, it depends on user's input """
+        return self._distributed_info
+
+    @property
+    def sharding_strategy(self):
+        """return the sharding strategy of current tensor"""
+        if self._sharding_strategy is None:
+            layout_info = self._layout.to_dict()
+            device_matrix = layout_info["device_matrix"]
+            tensor_map = layout_info["tensor_map"]
+            sharding_strategy = []
+            for map_value in tensor_map:
+                if isinstance(map_value, (tuple, list)):
+                    shard_size = 1
+                    for value in map_value:
+                        if value != -1:
+                            shard_size *= device_matrix[len(device_matrix) - value - 1]
+                    sharding_strategy.append(shard_size)
+                else:
+                    if map_value != -1:
+                        sharding_strategy.append(device_matrix[len(device_matrix) - map_value - 1])
+                    else:
+                        sharding_strategy.append(1)
+            self._sharding_strategy = sharding_strategy
+        return self._sharding_strategy
+
+
 class Layout:
     """
-
+    Topological abstraction describing cluster devices for tensor slice placement on the cluster.
 
     Note:
         - It is valid only in semi auto parallel or auto parallel mode.
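
The sharding_strategy property added above turns a layout's tensor_map back into per-dimension shard counts: an entry v selects device_matrix[len(device_matrix) - v - 1], a tuple entry multiplies the axes it names, and -1 means the dimension is not sharded. A self-contained sketch of that lookup (no MindSpore imports; values chosen to match the docstring example in the diff):

def sharding_strategy(device_matrix, tensor_map):
    """Derive per-dimension shard counts the same way the diff above does."""
    def axis_size(v):
        # -1 marks an unsharded dimension; otherwise index the device matrix from the right.
        return 1 if v == -1 else device_matrix[len(device_matrix) - v - 1]

    strategy = []
    for entry in tensor_map:
        if isinstance(entry, (tuple, list)):   # e.g. a ("dp", "interleaved_parallel") pairing
            size = 1
            for v in entry:
                if v != -1:
                    size *= axis_size(v)
            strategy.append(size)
        else:
            strategy.append(axis_size(entry))
    return strategy

# Layout((2, 2), ("dp", "mp")) with layout("dp", "mp") yields tensor_map (1, 0):
print(sharding_strategy((2, 2), (1, 0)))   # [2, 2], matching the docstring example above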
@@ -35,28 +98,36 @@ class Layout:
         alias_name (tuple): The alias name for each axis of device_matrix, its length shoits element type is string.
             When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
             copies on the corresponding partition dimension on a single card.
+        rank_list (list, optional): Data is allocated to the device according to rank_list. Default: ``None``.
+
     Raises:
         TypeError: `device_matrix` is not a tuple type.
         TypeError: `alias_name` is not a tuple type.
+        TypeError: 'rank_list' is not a list type.
         ValueError: `device_matrix` length is not equal to `alias_name` length.
         TypeError: The element of `device_matrix` is not int type.
         TypeError: The element of `alias_name` is not a str type.
+        TypeError: The element of `rank_list` is not int type.
         ValueError: The element of `alias_name` is an empty str.
         ValueError: The element of `alias_name` is "None".
         ValueError: `alias_name` contains repeated element.
 
+    Supported Platforms:
+        ``Ascend``
+
     Examples:
-        >>> from mindspore import Layout
+        >>> from mindspore.parallel import Layout
         >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
         >>> layout0 = layout("dp", "mp")
         >>> print(layout0.to_dict())
-        {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False
+        {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False,
+        'alias_name': {'dp', 'sp', 'mp'}, "rank_list": [0, 1, 2, 3, 4, 1, 6, 7]}
         >>> # Total device num is 4, but split the tensor in local device into two copies.
         >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
         >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
     """
 
-    def __init__(self, device_matrix, alias_name):
+    def __init__(self, device_matrix, alias_name, rank_list=None):
         if not isinstance(device_matrix, tuple):
             raise TypeError(f'device_matrix must be tuple type, but got:{type(device_matrix)}')
         if not isinstance(alias_name, tuple):
@@ -82,6 +153,20 @@ class Layout:
         self._device_shape = device_matrix
         self._alias_name = alias_name
         self._tensor_map = None
+        self._rank_list = list(range(np.prod(np.array(self._device_shape))))
+        if rank_list is not None:
+            if not isinstance(rank_list, list):
+                raise TypeError(f"The rank_list should be a list, but got {type(rank_list).__name__}.")
+            for in_ele in rank_list:
+                if not isinstance(in_ele, int):
+                    raise TypeError(f"The element of rank_list should be int, but got {type(in_ele).__name__}.")
+            if len(np.array(rank_list).shape) != 1:
+                raise ValueError(
+                    f"The rank_list should be a 1-D list, but got {len(np.array(rank_list).shape)}-D list.")
+            if len(rank_list) != np.prod(np.array(self._device_shape)):
+                raise ValueError(f"The length of rank_list should be equal to the product of device_matrix, "
+                                 f"but got {len(rank_list)} and {np.prod(np.array(self._device_shape))}.")
+            self._rank_list = rank_list
 
     def __call__(self, *tensor_map):
         self._tensor_map = ()
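
For the new rank_list argument shown above, a brief usage sketch based only on the constructor signature and validation rules in this diff; treat it as an assumption about the 2.6 API rather than verified behaviour:

from mindspore.parallel import Layout  # import path used by the 2.6 docstrings above

# Four devices arranged as a 2x2 device matrix, mapped onto ranks 4..7 instead of 0..3.
# The length of rank_list must equal the product of device_matrix, per the checks above.
layout = Layout((2, 2), ("dp", "mp"), rank_list=[4, 5, 6, 7])
l0 = layout("dp", "mp")
# Per the to_dict() update later in this diff, the returned dict also carries "rank_list".
print(l0.to_dict()["rank_list"])   # expected: [4, 5, 6, 7]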
@@ -122,8 +207,8 @@
             raise ValueError("The tensor_map of layout is None")
         interleaved_parallel = "interleaved_parallel" in self._alias_name
         return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map,
-                "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name
-
+                "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name,
+                "rank_list": self._rank_list}
 
 
 class Shard(Shard_):
@@ -141,18 +226,6 @@ class Shard(Shard_):
         self.level = None
 
     def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
-        parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
-        if parallel_mode not in ("auto_parallel", "semi_auto_parallel"):
-            raise AssertionError(
-                f"Cell shard only supports auto parallel and semi auto parallel.")
-        if ms.context.get_context("device_target") not in ("Ascend", "GPU"):
-            raise AssertionError(
-                f"'Shard' now only supports 'Ascend' and 'GPU'")
-        if parallel_mode == "auto_parallel" and \
-                ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
-            raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard' when the "
-                                 f"'parallel_mode' is 'auto_parallel.'")
-
         if not isinstance(in_strategy, tuple):
             raise TypeError(
                 f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}.")
@@ -181,7 +254,8 @@
                            "will be overwritten as False.")
             ms.set_algo_parameters(fully_use_devices=False)
 
-        if ms.context.get_auto_parallel_context("full_batch_is_set") is False
+        if ms.context.get_auto_parallel_context("full_batch_is_set") is False and \
+                ms.context.get_context("mode") == ms.context.PYNATIVE_MODE:
             logger.warning("When calling the shard interface, "
                            "'dataset_strategy' or 'full_batch' is not manually set by the user, "
                            "and the 'dataset_strategy' will be set to 'full_batch'.")
@@ -193,13 +267,13 @@
 
         if isinstance(fn, ms.nn.Cell):
             for param in fn.trainable_params():
-                param.
+                param.param_info.is_in_pynative_shard = True
 
         # Set parameter layout to corresponding parameter
         self._set_param_layout_into_parameter(fn, parameter_plan)
 
         def shard_fn(*args):
-            @ms.common.jit(hash_args=fn)
+            @ms.common.jit(hash_args=fn, backend="ms_backend")
             def after_shard(*args):
                 return shard_(fn, in_strategy, out_strategy, device, level)(*args)
 
@@ -290,7 +364,7 @@
         for stra in strategy:
             if not isinstance(stra, (tuple, Layout)):
                 raise TypeError(
-                    f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.Layout), "
+                    f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.parallel.Layout), "
                     f"but got {type(stra).__name__}")
             if isinstance(stra, Layout):
                 strategy_set.add("layout")
@@ -312,7 +386,7 @@
         for in_ele in layout:
             if not isinstance(in_ele, Layout):
                 raise TypeError(f"The {log_info} item should be a object of class Layout.")
-            layout_value += (in_ele.to_dict(),)
+            layout_value += ({k: v for k, v in in_ele.to_dict().items() if k != "rank_list"},)
         return layout_value
 
     def _check_tuple_strategy(self, dim_strategy):
@@ -323,8 +397,8 @@
 
 def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
     """
-
-
+    Specify the input and output slicing strategy for a Cell or function.
+    In PyNative mode, use this method to specify a Cell for distributed
     execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
     strategy for others will be set by sharding propagation.
     in_strategy and out_strategy define the input and output layout respectively.
@@ -334,33 +408,37 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
     The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
 
     Note:
-        If
-
-        If the input contain Parameter, its strategy should be set in `in_strategy`.
+        - If shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
+          "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
+        - If the input contain Parameter, its strategy should be set in `in_strategy`.
+        - This method currently does not support dynamic shapes.
 
     Args:
         fn (Union[Cell, Function]): Function to be executed in parallel.
-            Its arguments and return value must be Tensor
+            Its arguments and return value must be Tensor.
            If `fn` is a Cell with parameters, `fn` needs to be an instantiated object,
            otherwise its arguments cannot be accessed.
        in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple(int) or
-            tuple(mindspore.Layout).
+            tuple(mindspore.parallel.Layout).
            Tuple defines the layout of the corresponding input.
-        out_strategy (Union[tuple, None]): Define the layout of outputs similar with `in_strategy`.
-
-        parameter_plan (Union[dict, None]): Define the layout for the specified parameters.
+        out_strategy (Union[tuple, None], optional): Define the layout of outputs similar with `in_strategy`.
+            Default: ``None`` .
+        parameter_plan (Union[dict, None], optional): Define the layout for the specified parameters.
+            Each element in dict
            defines the layout of the parameter like "param_name: layout".
            The key is a parameter name of type 'str'.
-            The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple,
+            The value is a 1-D integer tuple or a 1-D mindspore.parallel.Layout tuple,
            indicating the corresponding layout.
            If the parameter name is incorrect or the corresponding parameter
-            has been set, the parameter setting will be ignored.
+            has been set, the parameter setting will be ignored. Supported
+            only when `fn` is a Cell with parameters.
            Default: ``None`` .
-        device (
-
-        level (int): Option for parallel strategy infer algorithm, namely the object function,
-
-
+        device (str, optional): Select a certain `device` target. It is not in use right now.
+            Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
+        level (int, optional): Option for parallel strategy infer algorithm, namely the object function,
+            maximize computation
+            over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
+            use right now. Support [0, 1, 2]. Default: ``0`` .
 
     Returns:
        Function, return the function that will be executed under auto parallel process.
@@ -370,26 +448,28 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
        AssertionError: If device_target it not "Ascend" or "GPU".
        TypeError: If `in_strategy` is not a tuple.
        TypeError: If `out_strategy` is not a tuple or None.
-       TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.Layout).
-       TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.Layout).
+       TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
+       TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
       TypeError: If `parameter_plan` is not a dict or None.
       TypeError: If any key in `parameter_plan` is not a str.
-       TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.Layout).
+       TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.parallel.Layout).
       TypeError: If `device` is not a str.
       TypeError: If `level` is not an integer.
 
    Supported Platforms:
-        ``Ascend``
+        ``Ascend``
 
    Examples:
       >>> import numpy as np
       >>> import mindspore as ms
-       >>> from mindspore import Tensor, nn
+       >>> from mindspore import Tensor, nn, ops
       >>> from mindspore.communication import init
+       >>> from mindspore.parallel import shard
+       >>> from mindspore.parallel import Layout
+       >>> from mindspore.nn.utils import no_init_parameters
+       >>> from mindspore.parallel.auto_parallel import AutoParallel
       >>> ms.set_context(mode=ms.GRAPH_MODE)
       >>> init()
-       >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
-       ...                              device_num=8)
       >>>
       >>> # Case 1: cell uses functional
       >>> class BasicBlock(nn.Cell):
@@ -401,7 +481,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
       >>>             x = ops.abs(x)
       >>>             return x + y
       >>>         # shard a function with tuple(int) strategies
-       >>>         self.shard_my_add =
+       >>>         self.shard_my_add = shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
       >>>
       >>>     def construct(self, x, u):
       >>>         x = self.gelu(x)
@@ -429,7 +509,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
       >>>         super(Net, self).__init__()
       >>>         # setting cell sharding strategy and parameter_plan by tuple(int)
       >>>         self.layer_net1 = NetForward()
-       >>>         self.layer_net1_shard =
+       >>>         self.layer_net1_shard = shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
       ...                                        parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
       >>>
       >>>         # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
@@ -437,7 +517,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
       >>>         layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
       >>>         in_layout = (layout("dp", "mp"), layout("mp", "sp"))
       >>>         param_layout = layout("dp", "sp")
-       >>>         self.layer_net2_shard =
+       >>>         self.layer_net2_shard = shard(self.layer_net2, in_strategy=in_layout,
       ...                                        parameter_plan={"self.layer_net2.block2.weight": param_layout})
       >>>         self.flatten = nn.Flatten()
       >>>         self.layer1 = nn.Dense(64, 64)
@@ -455,26 +535,25 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
       >>>         x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
       >>>         return x
       >>>
-       >>>
+       >>> with no_init_parameters():
+       >>>     net = Net()
       >>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
       >>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
-       >>> net
+       >>> parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
+       >>> parallel_net(x, y)
       >>>
       >>> # Case 2: function uses functional sharding
       >>> def test_shard(x, y):
       ...     return x + y
       >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
       >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
-       >>> output =
+       >>> output = shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
       >>> print(output.shape)
       (32, 10)
 
-    Tutorial Examples:
-        - `Functional Operator Sharding
-          <https://www.mindspore.cn/docs/en/master/model_train/parallel/shard_function_parallel.html>`_
-        - `mindspore.Layout
-          <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html>`_
    """
+    if ms.communication.management.get_group_size() == 1:
+        return fn
    if not isinstance(fn, (ms.nn.Cell)):
        logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
                       "otherwise, the result may be incorrect.")
|