mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/parallel/_auto_parallel_context.py

@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from mindspore import context
 import mindspore.log as logger
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
+from mindspore.parallel.shard import Layout
 from mindspore._c_expression import AutoParallelContext
 from mindspore._checkparam import args_type_check
 from mindspore import _checkparam as Validator
@@ -63,6 +64,7 @@ class _ParallelOptimizerConfig:
     GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
     PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"
     OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"
+    OPTIMIZER_LEVEL = "optimizer_level"


 class _PipelineConfig:
@@ -77,6 +79,8 @@ class _PipelineScheduler:
     PIPELINE_1F1B = "1f1b"
     PIPELINE_GPIPE = "gpipe"
     PIPELINE_SEQPIPE = "seqpipe"
+    PIPELINE_SEQVPP = "seqvpp"
+    PIPELINE_SEQSMARTVPP = "seqsmartvpp"


 class _AutoParallelContext:
@@ -100,6 +104,7 @@ class _AutoParallelContext:
     def __init__(self):
         self._context_handle = AutoParallelContext.get_instance()
         self._dataset_strategy_using_str = True
+        self._dataset_layout = None

     def check_context_handle(self):
         """
@@ -187,6 +192,25 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_dump_local_norm()

+    def set_dump_local_norm_path(self, dump_local_norm_path):
+        """
+        Set dump local norm path for auto parallel.
+
+        Args:
+            dump_local_norm_path (str): User need to specify the path to save dump files
+                if he want to dump local norm. Default: ''
+
+        Raises:
+            KeyError: When the value of dump_local_norm_path is not a str value.
+        """
+        self.check_context_handle()
+        self._context_handle.set_dump_local_norm_path(dump_local_norm_path)
+
+    def get_dump_local_norm_path(self):
+        """Get dump local norm path."""
+        self.check_context_handle()
+        return self._context_handle.get_dump_local_norm_path()
+
     def set_dump_device_local_norm(self, dump_device_local_norm):
         """
         Set dump device local norm for auto parallel.
@@ -195,7 +219,7 @@ class _AutoParallelContext:
             dump_device_local_norm (bool): User need to specify if he want to dump device local norm. Default: False

         Raises:
-            ValueError: If the dump_device_local_norm
+            ValueError: If the dump_device_local_norm is not a bool value.
         """
         self.check_context_handle()
         self._context_handle.set_dump_device_local_norm(dump_device_local_norm)
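The two hunks above add a configurable dump path next to the existing dump_local_norm switch. As a rough, hedged sketch (assuming both keys are accepted through mindspore.set_auto_parallel_context, as the updated function maps later in this diff suggest), usage might look like:

    import mindspore as ms

    # Illustrative values only; the path is an example, not taken from the diff.
    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    ms.set_auto_parallel_context(dump_local_norm=True,
                                 dump_local_norm_path="./local_norm_dump")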
@@ -422,6 +446,9 @@ class _AutoParallelContext:
             raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', "
                              "'data_parallel', 'hybrid_parallel', 'semi_auto_parallel' and 'auto_parallel', "
                              "but got the value : {}.".format(parallel_mode))
+        if run_mode == context.ParallelMode.DATA_PARALLEL and self.get_enable_parallel_optimizer():
+            logger.warning("'enable_parallel_optimizer' is not suggested in 'data_parallel' mode, "
+                           "consider using 'semi_auto_parallel' or 'auto_parallel' mode.")

     def get_parallel_mode(self):
         """Get parallel mode."""
@@ -566,6 +593,9 @@ class _AutoParallelContext:
         if not isinstance(dataset_strategy, tuple):
             raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
                             "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
+        if dataset_strategy and isinstance(dataset_strategy[0], Layout):
+            self._set_dataset_strategy_layout(dataset_strategy)
+            return
         for ele in dataset_strategy:
             if not isinstance(ele, tuple):
                 raise TypeError("For 'set_auto_parallel_context', the element of argument "
@@ -580,8 +610,36 @@ class _AutoParallelContext:
         self._dataset_strategy_using_str = False
         self._context_handle.set_dataset_strategy(dataset_strategy)

+    def _set_dataset_strategy_layout(self, dataset_strategy):
+        """set dataset layout to c++ by using pybind."""
+        dataset_devmat = []
+        dataset_tensormap = []
+        dataset_alias_name = []
+        self._dataset_layout = dataset_strategy
+        for ele in dataset_strategy:
+            if not isinstance(ele, Layout):
+                raise TypeError(f"All the dataset_strategy elements should be Layout, but got {type(ele)}")
+            layout_to_dict = ele.to_dict()
+            dataset_devmat.append(layout_to_dict["device_matrix"])
+            dataset_alias_name.append(layout_to_dict["alias_name"])
+            if layout_to_dict["interleaved_parallel"]:
+                raise ValueError("For dataset_strategy, layout does not support interleaved_parallel")
+            tensor_map = []
+            for value in layout_to_dict["tensor_map"]:
+                if isinstance(value, tuple):
+                    tensor_map.append(value)
+                elif isinstance(value, int):
+                    tensor_map.append((value,))
+                else:
+                    raise TypeError(f"value in tensor map must be tuple or int, but got {type(value)}")
+            dataset_tensormap.append(tuple(tensor_map))
+        self._context_handle.set_dataset_layout(dataset_devmat, dataset_tensormap, dataset_alias_name)
+
+
     def get_dataset_strategy(self):
         """Get dataset sharding strategy."""
+        if self._dataset_layout is not None:
+            return self._dataset_layout
         self.check_context_handle()
         if self._dataset_strategy_using_str:
             if self._context_handle.get_full_batch():
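The hunk above lets dataset_strategy carry mindspore Layout objects instead of plain integer tuples, forwarding their device matrix, tensor map, and alias names to the C++ context. A minimal sketch, assuming the Layout constructor arguments below (they are illustrative, not taken from this diff):

    import mindspore as ms
    from mindspore import Layout

    # 8 devices split 4-way data-parallel and 2-way model-parallel (example values).
    layout = Layout((4, 2), ("dp", "mp"))
    input_layout = layout("dp", "mp")  # shard dim 0 over "dp", dim 1 over "mp"

    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                 dataset_strategy=(input_layout,))
    # Per the new get_dataset_strategy(), the cached Layout tuple is returned back.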
@@ -869,6 +927,9 @@ class _AutoParallelContext:
                             "the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
                             .format(type(enable_parallel_optimizer)))
         self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
+        if enable_parallel_optimizer and self.get_parallel_mode() == context.ParallelMode.DATA_PARALLEL:
+            logger.warning("'enable_parallel_optimizer' is not suggested in 'data_parallel' mode, "
+                           "consider using 'semi_auto_parallel' or 'auto_parallel' mode.")

     def set_force_fp32_communication(self, force_fp32_communication):
         """
@@ -899,7 +960,7 @@ class _AutoParallelContext:

             - pipeline_interleave(bool): Setting true enable interleave scheduler for pipeline parallelism. This
               scheduler requires more memory but less bubble.
-            - pipeline_scheduler(
+            - pipeline_scheduler(str): There are two choices, "1f1b" and "gpipe". default is "1f1b"

               - 1f1b: It requires less memory and bubble ratio, for it run backward pass when corresponding forward pass
                 finished.
@@ -935,7 +996,9 @@ class _AutoParallelContext:

         Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
                                                                _PipelineScheduler.PIPELINE_GPIPE,
-                                                               _PipelineScheduler.PIPELINE_SEQPIPE
+                                                               _PipelineScheduler.PIPELINE_SEQPIPE,
+                                                               _PipelineScheduler.PIPELINE_SEQVPP,
+                                                               _PipelineScheduler.PIPELINE_SEQSMARTVPP])
         if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
             raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")

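The whitelist above now accepts "seqvpp" and "seqsmartvpp" alongside "1f1b", "gpipe", and "seqpipe", and anything other than "1f1b" still requires pipeline_interleave=True. A hedged configuration sketch (the stage count is illustrative):

    import mindspore as ms

    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                 pipeline_stages=4,
                                 pipeline_config={"pipeline_interleave": True,
                                                  "pipeline_scheduler": "seqvpp"})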
@@ -975,19 +1038,21 @@ class _AutoParallelContext:
                                                 shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
             - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specific the
                                                 maximum group size across devices when the parallel optimizer is
-                                                enabled. The numerical range can be (0, device_num].
-                                                is
-
-
+                                                enabled. The numerical range can be (0, device_num] or -1. If pipeline
+                                                parallelism is enabled, the numerical range is (0, device_num/stage]
+                                                or -1. Default value is -1, which means the optimizer weight shard
+                                                group size will be equal to the data parallel group of each parameter.
         """
         self.check_context_handle()
         grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
         threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
         optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE
+        optimizer_level_name = _ParallelOptimizerConfig.OPTIMIZER_LEVEL

         for config_name in parallel_optimizer_config:
             unknown_config = []
-            if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name
+            if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name,
+                                   optimizer_level_name]:
                 unknown_config.append(config_name)

         if unknown_config:
@@ -998,6 +1063,10 @@ class _AutoParallelContext:
                 parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
             self._context_handle.set_grad_accumulation_shard(
                 parallel_optimizer_config[grad_shard_name])
+            if optimizer_level_name in parallel_optimizer_config \
+                    and parallel_optimizer_config[optimizer_level_name] != "level2":
+                raise ValueError(f"The optimizer_level is set as {parallel_optimizer_config[optimizer_level_name]}, "
+                                 "thus cannot set grad_accumulation_shard as True.")

         if threshold_name in parallel_optimizer_config:
             Validator.check_non_negative_int(
@@ -1007,8 +1076,23 @@ class _AutoParallelContext:

         if optimizer_weight_shard_size_name in parallel_optimizer_config:
             value = parallel_optimizer_config[optimizer_weight_shard_size_name]
-
-
+            if value != -1:
+                Validator.check_positive_int(value, prim_name="optimizer_weight_shard_size")
+            self.set_optimizer_weight_shard_size(value)
+
+        if optimizer_level_name in parallel_optimizer_config:
+            optimizer_level = parallel_optimizer_config[optimizer_level_name]
+            if optimizer_level not in ["level1", "level2", "level3"]:
+                raise ValueError("Optimizer level should in ['level1', 'level2', 'level3'], but got {}"
+                                 .format(optimizer_level))
+
+            if self._context_handle.get_grad_accumulation_shard() and optimizer_level != "level2":
+                raise ValueError("The grad_accumulation shard is set, thus cannot set optimizer_level != 'level2'")
+            if optimizer_level == "level2":
+                self._context_handle.set_grad_accumulation_shard(True)
+            if optimizer_level == "level3":
+                self._context_handle.set_zero3(True)
+                self._context_handle.set_grad_accumulation_shard(False)

     def get_grad_accumulation_shard(self):
         """Get grad accumulation shard."""
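The new "optimizer_level" key takes "level1", "level2", or "level3"; per the checks above, "level2" implies gradient-accumulation sharding while "level3" switches it off in favor of zero3-style sharding. A hedged sketch of passing it through parallel_optimizer_config (values are illustrative):

    import mindspore as ms

    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                 enable_parallel_optimizer=True,
                                 parallel_optimizer_config={"optimizer_level": "level2",
                                                            "parallel_optimizer_threshold": 64})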
@@ -1089,13 +1173,6 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_optimizer_weight_shard_size()

-    def set_ops_strategy_json_config(self, type, path, mode):
-        """
-        Set configuration of saving ops strategy in file .json.
-        """
-        self.check_context_handle()
-        self._context_handle.set_ops_strategy_json_config(type, path, mode)
-
     def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
         """
         Set optimizer_weight_shard_aggregated_save.
@@ -1124,6 +1201,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         self._context_handle.reset()
         _ParallelFusionConfig.reset()
+        self._dataset_layout = None

     def _check_and_default_group(self, group):
         """Validate the given group, if group is empty, returns a default fusion group"""
@@ -1233,30 +1311,35 @@ class _AutoParallelContext:
         self.set_enable_all_gather_fusion(openstate)
         self.set_enable_reduce_scatter_fusion(openstate)

+    def set_auto_parallel_new_interface(self, auto_parallel_new_interface):
+        """
+        Set AutoParallel(cell) new interface flag.
+
+        Args:
+            auto_parallel_new_interface (bool): Mark whether to use the new interface.
+        """
+        self.check_context_handle()
+        self._context_handle.set_auto_parallel_new_interface(auto_parallel_new_interface)

-    def
-
-
+    def get_auto_parallel_new_interface(self):
+        """Get auto_parallel_new_interface."""
+        self.check_context_handle()
+        return self._context_handle.get_auto_parallel_new_interface()

-
-
-
-            mode (str): The parameter for choosing save all or important operators.
+    def set_init_param_in_compile(self, init_param_in_compile):
+        """
+        Set flag marking whether to init parameters in compiling process.

-
-
-
-
-
-        if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path, mode=0o700, exist_ok=True)
-        check_type = ["SAVE", "LOAD"]
-        check_mode = ["all", "principal"]
-        if type in check_type and mode in check_mode:
-            auto_parallel_context().set_ops_strategy_json_config(type, path, mode)
-        else:
-            raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")
+        Args:
+            init_param_in_compile (bool): Mark whether to init parameters in compiling process.
+        """
+        self.check_context_handle()
+        self._context_handle.set_init_param_in_compile(init_param_in_compile)

+    def get_init_param_in_compile(self):
+        """Get init_param_in_compile."""
+        self.check_context_handle()
+        return self._context_handle.get_init_param_in_compile()

 _AUTO_PARALLEL_CONTEXT = None

@@ -1307,7 +1390,11 @@ _set_auto_parallel_context_func_map = {
     "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
     "comm_fusion": auto_parallel_context().set_comm_fusion,
     "dump_local_norm": auto_parallel_context().set_dump_local_norm,
-    "
+    "dump_local_norm_path": auto_parallel_context().set_dump_local_norm_path,
+    "dump_device_local_norm": auto_parallel_context().set_dump_device_local_norm,
+    "auto_parallel_new_interface": auto_parallel_context().set_auto_parallel_new_interface,
+    "init_param_in_compile": auto_parallel_context().set_init_param_in_compile}
+

 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
@@ -1341,7 +1428,10 @@ _get_auto_parallel_context_func_map = {
     "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
     "full_batch_is_set": auto_parallel_context().get_full_batch_is_set,
     "dump_local_norm": auto_parallel_context().get_dump_local_norm,
-    "
+    "dump_local_norm_path": auto_parallel_context().get_dump_local_norm_path,
+    "dump_device_local_norm": auto_parallel_context().get_dump_device_local_norm,
+    "auto_parallel_new_interface": auto_parallel_context().get_auto_parallel_new_interface,
+    "init_param_in_compile": auto_parallel_context().get_init_param_in_compile}


 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@@ -1452,6 +1542,8 @@ def _set_auto_parallel_context(**kwargs):
               - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
                 and `size`. Config is same as `allgather`.

+
+
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
     """
@@ -1481,6 +1573,33 @@ def _get_auto_parallel_context(attr_key):
     return get_func()


+def _get_all_auto_parallel_context():
+    """get auto parallel context before reset"""
+    _auto_paralell_context_value_map = {}
+    _pipeline_config = {}
+    for key, value in _get_auto_parallel_context_func_map.items():
+        if key == "pipeline_interleave":
+            _pipeline_config[key] = value()
+        elif key == "pipeline_scheduler":
+            _pipeline_config[key] = value()
+        else:
+            _auto_paralell_context_value_map[key] = value()
+    return _auto_paralell_context_value_map, _pipeline_config
+
+
+def _recover_auto_parallel_context(context_value_map, pp_config):
+    """set auto parallel context after transformation"""
+    # set the same auto parallel context after transform
+    from mindspore.context import reset_auto_parallel_context
+    reset_auto_parallel_context()
+    for key, value in context_value_map.items():
+        # list is empty or full_batch_is_set is not needed to set
+        if (isinstance(value, list) and not value) or (key == "full_batch_is_set"):
+            continue
+        _set_auto_parallel_context_func_map[key](value)
+    _set_auto_parallel_context_func_map["pipeline_config"](pp_config)
+
+
 def _reset_auto_parallel_context():
     """
     Reset auto parallel context attributes to the default values:
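_get_all_auto_parallel_context and _recover_auto_parallel_context snapshot every getter in the function map and replay the values after a reset. These are private helpers, so the following is only a hypothetical sketch of the save/reset/restore flow they enable:

    from mindspore.parallel._auto_parallel_context import (
        _get_all_auto_parallel_context, _recover_auto_parallel_context)

    # Snapshot the current auto-parallel context (pipeline settings are kept separately).
    context_value_map, pp_config = _get_all_auto_parallel_context()
    # ... run a transformation that calls reset_auto_parallel_context() internally ...
    _recover_auto_parallel_context(context_value_map, pp_config)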
mindspore/parallel/_cell_wrapper.py

@@ -18,17 +18,20 @@ from __future__ import division

 import numpy as np

+import mindspore.log as logger
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
-from mindspore.communication import GlobalComm
+from mindspore.communication import GlobalComm, get_rank
 from mindspore.common import jit
-from mindspore.communication import create_group, destroy_group
+from mindspore.communication import create_group, destroy_group, get_group_size
 from mindspore.communication._comm_helper import _get_group_map
 from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
+from mindspore.parallel.shard import Layout

 _ALLGATHER_CELL = None
+ALLREDUCE_GROUP_LIST = []


 class AllGatherCell(Cell):
@@ -134,7 +137,7 @@ def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):

 def _get_group_name(group_map, group):
     """get group name"""
-    group_name = str(group)
+    group_name = "remove_redundancy" + str(group)
     is_manual_communication_group = True
     if group_map:
         for name, rank_list in group_map.items():
@@ -142,20 +145,37 @@ def _get_group_name(group_map, group):
                 group_name = name
                 is_manual_communication_group = False
                 break
-    if is_manual_communication_group:
-        create_group(str(group), list(group))
     return group_name, is_manual_communication_group


-def
+def _get_param_redundancy_reversed(param_redundancy, cur_rank):
+    """Generate the reverse mapping of parameter redundancy based on the current rank."""
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    return param_redundancy_reversed
+
+
+def _remove_param_not_load(param_name, param_not_load):
+    """Remove param_name from param_not_load."""
+    if param_not_load is not None and param_name in param_not_load:
+        param_not_load.remove(param_name)
+
+
+def _single_parameter_broadcast(net, layout, param_not_load=None):
     """
     Broadcast single parameter to other rank in data parallel dimension.
     """
     from mindspore import Tensor
     origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
     origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    cur_rank = get_rank()
     if layout:
-        param_redundancy = get_parameter_redundancy(layout
+        param_redundancy = get_parameter_redundancy(layout)
     else:
         param_redundancy = get_parameter_redundancy(net)
     if not param_redundancy:
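_get_param_redundancy_reversed above inverts the per-parameter redundancy map into rank groups that share copies of the same parameters. A small self-contained illustration of that inversion with toy data (parameter names and rank tuples are made up):

    def get_param_redundancy_reversed(param_redundancy, cur_rank):
        # Same logic as the new helper: group parameter names by the rank tuple that
        # holds redundant copies, keeping only tuples containing cur_rank.
        reversed_map = {}
        for name, redundancy in param_redundancy.items():
            for ranks in redundancy:
                if len(ranks) == 1:
                    continue
                if cur_rank in ranks:
                    reversed_map.setdefault(ranks, []).append(name)
        return reversed_map

    toy_redundancy = {"layer1.weight": ((0, 1), (2, 3)), "layer2.weight": ((0, 2),)}
    print(get_param_redundancy_reversed(toy_redundancy, cur_rank=0))
    # {(0, 1): ['layer1.weight'], (0, 2): ['layer2.weight']}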
@@ -163,33 +183,130 @@ def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
     single_params = remove_param_redundancy(param_redundancy)
     if not single_params:
         return
-    param_redundancy_reversed =
-    for key, redundancy in param_redundancy.items():
-        for item in redundancy:
-            if len(item) == 1:
-                continue
-            if cur_rank in item:
-                param_redundancy_reversed.setdefault(item, []).append(key)
+    param_redundancy_reversed = _get_param_redundancy_reversed(param_redundancy, cur_rank)
     if not param_redundancy_reversed or cur_rank not in single_params:
         return
     net_param_dict = net.parameters_dict()
     _chang_parallel_context(origin_dataset_strategy)
     group_map = _get_group_map()
+    if group_map:
+        group_map = {key: group_map[key] for key in sorted(group_map.keys())}
     for group, params in param_redundancy_reversed.items():
         group_name, is_manual_communication_group = _get_group_name(group_map, group)
         allreduce_input = []
         for param in params:
             if param not in net_param_dict:
                 continue
+            if param.startswith("accu_grads") or param.endswith("expert_load"):
+                continue
             real_param = net_param_dict[param]
+            _remove_param_not_load(real_param.name, param_not_load)
             if param not in single_params[cur_rank]:
                 real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
             allreduce_input.append(real_param)
         if not allreduce_input:
             continue
+        if is_manual_communication_group:
+            create_group(group_name, list(group))
+        allreduce_input.sort(key=lambda param: (str(param.shape), str(param.dtype)))
         communicator = SingleCommunicator(group_name)
         for real_param in allreduce_input:
-            real_param.set_data(communicator(real_param), real_param.sliced)
+            real_param.set_data(communicator(Tensor(real_param)), real_param.sliced)
         if is_manual_communication_group:
             destroy_group(group_name)
     _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
+
+
+def _insert_virtual_pp_dim(layout):
+    """insert virtual pp dim in device matrix and create new layout"""
+    if len(layout.to_dict()["rank_list"]) == get_group_size():
+        return layout
+    remain_pp = get_group_size() // len(layout.to_dict()["rank_list"])
+    layout_info = layout.to_dict()
+    device_matrix = layout_info["device_matrix"]
+    tensor_map = layout_info["tensor_map"]
+    alias_name = layout_info["alias_name"]
+    new_devmat = Layout((remain_pp,) + device_matrix, ("remain_pp",) + alias_name)
+    tensor_map_alias_name = []
+    for val in tensor_map:
+        sub_alias_name = []
+        if isinstance(val, tuple):
+            for sub_val in val:
+                if sub_val == -1:
+                    sub_alias_name.append("None")
+                else:
+                    sub_alias_name.append(alias_name[len(device_matrix) - sub_val - 1])
+            tensor_map_alias_name.append(tuple(sub_alias_name))
+        else:
+            if val == -1:
+                tensor_map_alias_name.append("None")
+            else:
+                tensor_map_alias_name.append(alias_name[len(device_matrix) - val - 1])
+    new_layout = new_devmat(*tensor_map_alias_name)
+    return new_layout
+
+
+class CommTensorDataForPP(Cell):
+    """Communicate tensor data for pipeline parallel scenario."""
+
+    def __init__(self, src_dtensor_info, dst_dtensor_info):
+        super().__init__()
+        self.zeros = P.Zeros()
+
+        self._current_rank_id = get_rank()
+        self._from_dev_num_in_stage = len(src_dtensor_info.layout.to_dict()["rank_list"])
+        self._from_rank_id = src_dtensor_info.layout.to_dict()["rank_list"]
+        self._current_rank_has_data = self._current_rank_id in src_dtensor_info.layout.to_dict()["rank_list"]
+        self._diff_rank_id = [
+            rank_id for rank_id in dst_dtensor_info.layout.to_dict()["rank_list"] if rank_id not in self._from_rank_id]
+        self._group, self._root_idx = self._create_all_reduce_group()
+
+    def comm_data(self, comm_data):
+        """communicate data"""
+        from mindspore import mint
+        comm_handle = mint.distributed.broadcast(comm_data, self._root_idx, self._group, async_op=False)
+        return comm_handle
+
+    def _create_all_reduce_group(self):
+        """create all reduce group"""
+        global ALLREDUCE_GROUP_LIST
+        current_rank_stage_id = self._current_rank_id // self._from_dev_num_in_stage
+        end_stage = self._from_dev_num_in_stage * (current_rank_stage_id + 1)
+        rank_pos_in_stage = [rank_id for rank_id in range(self._from_dev_num_in_stage * current_rank_stage_id,
+                                                          end_stage)].index(self._current_rank_id)
+        root_idx = self._from_rank_id[rank_pos_in_stage]
+        all_reduce_rank_list = [self._from_rank_id[rank_pos_in_stage]]
+        while rank_pos_in_stage < len(self._diff_rank_id):
+            all_reduce_rank_list.append(self._diff_rank_id[rank_pos_in_stage])
+            rank_pos_in_stage += self._from_dev_num_in_stage
+        all_reduce_rank_list.sort()
+        str_rank_list = '-'.join([str(rank) for rank in all_reduce_rank_list])
+        all_reduce_group = f"pp_allreduce_group-{str_rank_list}"
+        if all_reduce_group in ALLREDUCE_GROUP_LIST:
+            return all_reduce_group, root_idx
+        ALLREDUCE_GROUP_LIST.append(all_reduce_group)
+        create_group(all_reduce_group, all_reduce_rank_list)
+        logger.debug(f"Create group {all_reduce_group} for tensor data communication.")
+        return all_reduce_group, root_idx
+
+
+class RedistributionCell(Cell):
+    """Redistribute src_layout to dst_layout"""
+
+    def __init__(self, src_layout, dst_layout):
+        super().__init__()
+        if src_layout is None or dst_layout is None:
+            raise ValueError("src_layout and dst_layout should not be None.")
+        self._total_dev_num = get_group_size()
+        src_layout = _insert_virtual_pp_dim(src_layout)
+        dst_layout = _insert_virtual_pp_dim(dst_layout)
+        self.src_identity = P.Identity().shard(in_strategy=(src_layout,), out_strategy=(src_layout,))
+        self.src_identity.add_prim_attr("self_define_shard", True)
+        self.dst_identity = P.Identity().shard(in_strategy=(dst_layout,), out_strategy=(dst_layout,))
+        self.dst_identity.add_prim_attr("self_define_shard", True)
+
+    def construct(self, input_tensor):
+        """run redistribution"""
+        src_tensor = self.src_identity(input_tensor)
+        dst_tensor = self.dst_identity(src_tensor)
+        return dst_tensor
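CommTensorDataForPP.comm_data above pushes tensor data from a root rank to the other pipeline-parallel ranks with mint.distributed.broadcast. A hedged standalone sketch of that call; it assumes a distributed job launched with msrun or mpirun and that init_process_group/get_rank are available in mindspore.mint.distributed:

    import mindspore as ms
    from mindspore import mint

    mint.distributed.init_process_group()  # assumes a launched distributed job
    rank = mint.distributed.get_rank()

    data = ms.Tensor([1.0, 2.0]) if rank == 0 else ms.Tensor([0.0, 0.0])
    # Synchronous broadcast from rank 0 to the default group, matching the diff's positional call style.
    mint.distributed.broadcast(data, 0, async_op=False)
    print(rank, data)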