mindspore-2.5.0-cp311-cp311-win_amd64.whl → mindspore-2.6.0-cp311-cp311-win_amd64.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
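Most of the functional change in this release sits under mindspore/parallel: auto_parallel.py, checkpoint_convert.py, and the nn/ and function/ sub-packages appear to be newly added, and the expanded parallel checkpoint module diffed below now exposes the distributed-checkpoint helpers directly from mindspore.parallel. The following is a minimal sketch of the 2.6-style entry points, assembled from the updated docstring examples visible in the diff below; it assumes the 2.6 wheel is installed, and the strategy/checkpoint paths are the illustrative placeholders used in those examples, not real files.

    import mindspore as ms
    from mindspore.parallel import build_searched_strategy

    # Merge the per-stage pipeline strategy files into a single destination file.
    ms.parallel.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")

    # Find which source-rank checkpoints are needed to assemble the checkpoint for rank 0.
    rank_list = ms.parallel.rank_list_for_transform(0, "./src_strategy.ckpt", "./dst_strategy.ckpt")
    checkpoint_files_map = {rank: "./pangu{}-100_2.ckpt".format(rank) for rank in rank_list}

    # Read the per-parameter sharding strategy recorded during training (Ascend only).
    strategy = build_searched_strategy("./strategy_train.ckpt")

The same module also re-exports load_distributed_checkpoint and merge_sliced_parameter for distributed-inference workflows, as shown in the __all__ change in the diff below.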
|
@@ -18,34 +18,46 @@ from __future__ import absolute_import
|
|
|
18
18
|
import os
|
|
19
19
|
import glob
|
|
20
20
|
import copy
|
|
21
|
+
from multiprocessing import Pool
|
|
21
22
|
from collections import defaultdict
|
|
22
23
|
import numpy as np
|
|
23
24
|
import mindspore as ms
|
|
25
|
+
from mindspore import log as logger
|
|
26
|
+
from mindspore import _checkparam as Validator
|
|
24
27
|
from mindspore.common import dtype as mstype
|
|
25
|
-
from mindspore.
|
|
28
|
+
from mindspore.common.parameter import Parameter
|
|
29
|
+
from mindspore.common.tensor import Tensor
|
|
30
|
+
from mindspore.communication.management import get_rank, get_group_size
|
|
31
|
+
from mindspore.parallel._tensor import _load_tensor, _reshape_param_data, _reshape_param_data_with_weight, \
|
|
32
|
+
_get_tensor_slice_index, _get_tensor_strategy
|
|
33
|
+
from mindspore.parallel._utils import _is_in_auto_parallel_mode, _get_pipeline_stages, _infer_rank_list, \
|
|
34
|
+
_remove_repeated_slices, _get_auto_parallel_net
|
|
26
35
|
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
|
|
27
|
-
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
|
|
36
|
+
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, _build_searched_strategy, \
|
|
28
37
|
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
|
|
29
|
-
_merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src
|
|
30
|
-
|
|
38
|
+
_merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src, _convert_to_list, \
|
|
39
|
+
_check_checkpoint_file, _check_predict_strategy, _gather_tasks_load_dis, _get_param_list_when_first_dim_sharded, \
|
|
40
|
+
_convert_to_layout, _restore_group_info_list
|
|
31
41
|
from mindspore._c_expression import AutoParallelContext
|
|
42
|
+
from mindspore.parallel.transform_safetensors import _transform_safetensors, _collect_safetensor_files, \
|
|
43
|
+
_load_parallel_checkpoint
|
|
32
44
|
|
|
33
45
|
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
|
|
34
|
-
"transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints"
|
|
46
|
+
"transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints",
|
|
47
|
+
"load_distributed_checkpoint", "merge_sliced_parameter", "restore_group_info_list",
|
|
48
|
+
"build_searched_strategy"]
|
|
35
49
|
|
|
36
50
|
|
|
37
51
|
def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
|
|
38
52
|
"""
|
|
39
|
-
|
|
40
|
-
For more details about converting distributed Checkpoint, please refer to
|
|
41
|
-
`Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
|
|
53
|
+
Aggregate the sharding strategy files of all pipeline parallel subgraphs to the destination file.
|
|
42
54
|
|
|
43
55
|
Note:
|
|
44
56
|
Strategy file of each pipeline stage should be included in src_strategy_dirs.
|
|
45
57
|
|
|
46
58
|
Args:
|
|
47
59
|
src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by
|
|
48
|
-
|
|
60
|
+
:func:`mindspore.parallel.auto_parallel.AutoParallel.save_param_strategy_file`.
|
|
49
61
|
dst_strategy_file (str): The file merged strategy to save.
|
|
50
62
|
|
|
51
63
|
Raises:
|
|
@@ -54,7 +66,7 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
|
|
|
54
66
|
Examples:
|
|
55
67
|
>>> import mindspore as ms
|
|
56
68
|
>>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
|
|
57
|
-
>>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
|
|
69
|
+
>>> ms.parallel.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
|
|
58
70
|
|
|
59
71
|
"""
|
|
60
72
|
dst_strategy_dir, _ = os.path.split(dst_strategy_file)
|
|
@@ -73,11 +85,211 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
|
|
|
73
85
|
_merge_json_strategy(src_strategy_files_json, dst_strategy_file)
|
|
74
86
|
|
|
75
87
|
|
|
88
|
+
def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
89
|
+
"""
|
|
90
|
+
Merge parameter slices into one parameter. Used in the case of distributed inference.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
|
|
94
|
+
strategy (Optional[dict], optional): Parameter slice strategy, whose key is parameter name and
|
|
95
|
+
value is slice strategy of this parameter. If strategy is None, just merge
|
|
96
|
+
parameter slices in 0 axis order. Default: ``None``.
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
Parameter, the merged parameter which has the whole data.
|
|
100
|
+
|
|
101
|
+
Raises:
|
|
102
|
+
ValueError: Failed to merge.
|
|
103
|
+
TypeError: The sliced_parameters is incorrect or strategy is not dict.
|
|
104
|
+
KeyError: The parameter name is not in keys of strategy.
|
|
105
|
+
|
|
106
|
+
Examples:
|
|
107
|
+
>>> import numpy as np
|
|
108
|
+
>>> import mindspore as ms
|
|
109
|
+
>>> from mindspore import Tensor, Parameter
|
|
110
|
+
>>>
|
|
111
|
+
>>> sliced_parameters = [
|
|
112
|
+
... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
|
|
113
|
+
... "network.embedding_table"),
|
|
114
|
+
... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
|
|
115
|
+
... "network.embedding_table"),
|
|
116
|
+
... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
|
|
117
|
+
... "network.embedding_table"),
|
|
118
|
+
... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
|
|
119
|
+
... "network.embedding_table")]
|
|
120
|
+
>>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
|
|
121
|
+
>>> print(merged_parameter)
|
|
122
|
+
Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
|
|
123
|
+
"""
|
|
124
|
+
if not isinstance(sliced_parameters, list):
|
|
125
|
+
raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
|
|
126
|
+
f"but got {type(sliced_parameters)}.")
|
|
127
|
+
|
|
128
|
+
if not sliced_parameters:
|
|
129
|
+
raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
|
|
130
|
+
|
|
131
|
+
if strategy and not isinstance(strategy, dict):
|
|
132
|
+
raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
|
|
133
|
+
f"but got {type(strategy)}.")
|
|
134
|
+
|
|
135
|
+
try:
|
|
136
|
+
parameter_name = sliced_parameters[0].name
|
|
137
|
+
parameter_shape = sliced_parameters[0].data.shape
|
|
138
|
+
parameter_shape_length = len(parameter_shape)
|
|
139
|
+
except BaseException as e:
|
|
140
|
+
raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
|
|
141
|
+
f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
|
|
142
|
+
|
|
143
|
+
is_even = True
|
|
144
|
+
for index, parameter in enumerate(sliced_parameters):
|
|
145
|
+
if not isinstance(parameter, Parameter):
|
|
146
|
+
raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
|
|
147
|
+
f"but got {type(parameter)} at index {index}.")
|
|
148
|
+
|
|
149
|
+
if parameter.name != parameter_name \
|
|
150
|
+
or len(parameter.data.shape) != parameter_shape_length \
|
|
151
|
+
or parameter.data.shape[1:] != parameter_shape[1:]:
|
|
152
|
+
raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
|
|
153
|
+
f" have the same name, dimension length and shape except 0 axis. The name, dimension "
|
|
154
|
+
f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
|
|
155
|
+
f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
|
|
156
|
+
f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
|
|
157
|
+
f"at index {index}.")
|
|
158
|
+
|
|
159
|
+
if parameter.data.shape != parameter_shape:
|
|
160
|
+
is_even = False
|
|
161
|
+
|
|
162
|
+
layerwise_parallel = sliced_parameters[0].layerwise_parallel
|
|
163
|
+
requires_grad = sliced_parameters[0].requires_grad
|
|
164
|
+
sliced_data = []
|
|
165
|
+
for parameter in sliced_parameters:
|
|
166
|
+
if parameter.data.dtype == mstype.bfloat16:
|
|
167
|
+
from mindspore.ops import Cast
|
|
168
|
+
cpu_cast = Cast().set_device("CPU")
|
|
169
|
+
sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
|
|
170
|
+
else:
|
|
171
|
+
sliced_data.append(parameter.data.asnumpy())
|
|
172
|
+
|
|
173
|
+
if not strategy:
|
|
174
|
+
merged_tensor = Tensor(np.concatenate(sliced_data))
|
|
175
|
+
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
176
|
+
|
|
177
|
+
else:
|
|
178
|
+
if parameter_name not in strategy.keys():
|
|
179
|
+
raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
|
|
180
|
+
f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
|
|
181
|
+
merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
|
|
182
|
+
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
183
|
+
|
|
184
|
+
return merged_parameter
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|
188
|
+
"""Merge sliced parameter and split it according to the predict strategy."""
|
|
189
|
+
merged_param = merge_sliced_parameter(sliced_params, train_strategy)
|
|
190
|
+
if not predict_strategy:
|
|
191
|
+
return merged_param
|
|
192
|
+
param_name = merged_param.name
|
|
193
|
+
tensor_layout = predict_strategy[param_name]
|
|
194
|
+
rank = get_rank()
|
|
195
|
+
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
|
|
196
|
+
requires_grad = merged_param.requires_grad
|
|
197
|
+
layerwise_parallel = merged_param.layerwise_parallel
|
|
198
|
+
if merged_param.data.dtype == mstype.bfloat16:
|
|
199
|
+
split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
|
|
200
|
+
else:
|
|
201
|
+
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
|
|
202
|
+
return split_param
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
206
|
+
"""
|
|
207
|
+
Merge data slices to one tensor with whole data when strategy is not None.
|
|
208
|
+
|
|
209
|
+
Args:
|
|
210
|
+
sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
|
|
211
|
+
parameter_name (str): Name of parameter.
|
|
212
|
+
strategy (dict): Parameter slice strategy.
|
|
213
|
+
is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
|
|
214
|
+
|
|
215
|
+
Returns:
|
|
216
|
+
Tensor, the merged Tensor which has the whole data.
|
|
217
|
+
|
|
218
|
+
Raises:
|
|
219
|
+
ValueError: Failed to merge.
|
|
220
|
+
"""
|
|
221
|
+
layout = strategy.get(parameter_name)
|
|
222
|
+
try:
|
|
223
|
+
dev_mat = list(layout.dev_matrix[0].dim)
|
|
224
|
+
tensor_map = list(layout.tensor_map[0].dim)
|
|
225
|
+
param_split_shape = list(layout.param_split_shape[0].dim)
|
|
226
|
+
field_size = int(layout.field)
|
|
227
|
+
except BaseException as e:
|
|
228
|
+
raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
|
|
229
|
+
f", please make sure that 'strategy' is correct.") from e
|
|
230
|
+
|
|
231
|
+
device_count = 1
|
|
232
|
+
for dim in dev_mat:
|
|
233
|
+
device_count *= dim
|
|
234
|
+
|
|
235
|
+
if len(sliced_data) != device_count:
|
|
236
|
+
raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
|
|
237
|
+
f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
|
|
238
|
+
f"device_count is {device_count}.")
|
|
239
|
+
|
|
240
|
+
if not param_split_shape:
|
|
241
|
+
if not is_even:
|
|
242
|
+
raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
|
|
243
|
+
"should be the same when slice manner is even.")
|
|
244
|
+
|
|
245
|
+
all_gather_tensor = Tensor(np.concatenate(sliced_data))
|
|
246
|
+
|
|
247
|
+
if field_size > 0:
|
|
248
|
+
merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
|
|
249
|
+
else:
|
|
250
|
+
merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
|
|
251
|
+
|
|
252
|
+
else:
|
|
253
|
+
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
|
254
|
+
|
|
255
|
+
slice_count = 1
|
|
256
|
+
for dim in tensor_strategy:
|
|
257
|
+
slice_count *= dim
|
|
258
|
+
|
|
259
|
+
if len(param_split_shape) != slice_count:
|
|
260
|
+
raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
|
|
261
|
+
f"{slice_count}, but got {len(param_split_shape)}.")
|
|
262
|
+
|
|
263
|
+
tensor_slices_new = list(range(slice_count))
|
|
264
|
+
tensor_slices = sliced_data
|
|
265
|
+
for i in range(device_count):
|
|
266
|
+
slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
|
|
267
|
+
if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
|
|
268
|
+
raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
|
|
269
|
+
f"{param_split_shape[slice_index]} in 0 axis, but got "
|
|
270
|
+
f"{tensor_slices[i].shape[0]}.")
|
|
271
|
+
tensor_slices_new[slice_index] = np.array(tensor_slices[i])
|
|
272
|
+
|
|
273
|
+
dim_len = len(tensor_strategy)
|
|
274
|
+
for i in range(dim_len):
|
|
275
|
+
ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
|
|
276
|
+
tensor_slices_new_inner = []
|
|
277
|
+
for j in range(ele_count):
|
|
278
|
+
new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
|
|
279
|
+
for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
|
|
280
|
+
(j + 1) * tensor_strategy[dim_len - 1 - i]):
|
|
281
|
+
new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
|
|
282
|
+
tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
|
|
283
|
+
tensor_slices_new = tensor_slices_new_inner
|
|
284
|
+
merged_tensor = Tensor(tensor_slices_new[0])
|
|
285
|
+
|
|
286
|
+
return merged_tensor
|
|
287
|
+
|
|
288
|
+
|
|
76
289
|
def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
|
|
77
290
|
"""
|
|
78
291
|
List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id during the
|
|
79
|
-
distributed checkpoint conversion.
|
|
80
|
-
`Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
|
|
292
|
+
distributed checkpoint conversion.
|
|
81
293
|
|
|
82
294
|
Args:
|
|
83
295
|
rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
|
|
@@ -102,7 +314,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
|
|
|
102
314
|
Examples:
|
|
103
315
|
>>> import mindspore as ms
|
|
104
316
|
>>> rank_id = 0
|
|
105
|
-
>>> rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
|
|
317
|
+
>>> rank_list = ms.parallel.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
|
|
106
318
|
>>> checkpoint_files_map = {}
|
|
107
319
|
>>> for rank in rank_list:
|
|
108
320
|
... checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
|
|
@@ -141,8 +353,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
|
|
|
141
353
|
src_strategy_file=None, dst_strategy_file=None):
|
|
142
354
|
"""
|
|
143
355
|
Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
|
|
144
|
-
for a network.
|
|
145
|
-
`Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
|
|
356
|
+
for a network.
|
|
146
357
|
|
|
147
358
|
Args:
|
|
148
359
|
rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
|
|
@@ -150,11 +361,11 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
|
|
|
150
361
|
the checkpoint file name.
|
|
151
362
|
save_checkpoint_file_name (str): The file name to save the converted checkpoint.
|
|
152
363
|
src_strategy_file (str): Name of source sharding strategy file which saved by
|
|
153
|
-
|
|
364
|
+
`mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
|
|
154
365
|
when the `src_strategy_file` is None, it means that the source sharding strategy is
|
|
155
366
|
without any sharing for each parameter. Default: ``None``.
|
|
156
367
|
dst_strategy_file (str): Name of destination sharding strategy file which saved by
|
|
157
|
-
|
|
368
|
+
`mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
|
|
158
369
|
when the `dst_strategy_file` is ``None``,
|
|
159
370
|
it means that the destination sharding strategy
|
|
160
371
|
is without any sharing for each parameter. Default: ``None``.
|
|
@@ -362,8 +573,6 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
|
|
|
362
573
|
dst_strategy_file=None, process_num=1, output_format="ckpt"):
|
|
363
574
|
"""
|
|
364
575
|
Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank.
|
|
365
|
-
For more details about converting distributed Checkpoint, please refer to
|
|
366
|
-
`Model Transformation <https://www.mindspore.cn/docs/en/master/model_train/parallel/model_transformation.html>`_.
|
|
367
576
|
|
|
368
577
|
Note:
|
|
369
578
|
The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
|
|
@@ -387,7 +596,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
|
|
|
387
596
|
is without any sharing for each parameter. Default:None.
|
|
388
597
|
process_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
|
|
389
598
|
output_format (str, optional): Control the format of the output checkpoint after conversion.
|
|
390
|
-
It can be set to either "ckpt" or "safetensors"
|
|
599
|
+
It can be set to either ``"ckpt"`` or ``"safetensors"``. Default: ``"ckpt"``.
|
|
391
600
|
|
|
392
601
|
Raises:
|
|
393
602
|
ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
|
|
@@ -473,18 +682,21 @@ def _sync_params(name, param, layout):
|
|
|
473
682
|
shape=param.shape,
|
|
474
683
|
dtype=param.dtype)(param))
|
|
475
684
|
|
|
476
|
-
|
|
685
|
+
# pylint: disable=W0212
|
|
477
686
|
def sync_pipeline_shared_parameters(net):
|
|
478
|
-
"""
|
|
479
|
-
|
|
687
|
+
"""Synchronization of shared weights between stages for pipeline parallel inference scenarios.
|
|
688
|
+
For example, `embedding table` is
|
|
480
689
|
shared by `WordEmbedding` layer and `LMHead` layer, which are usually split into different stages. It is necessary
|
|
481
690
|
to perform synchronization after `embedding table` changes.
|
|
482
691
|
|
|
483
692
|
Note:
|
|
484
|
-
The network should be compiled before
|
|
693
|
+
The network should be compiled before shared parameters are synchronized in the pipeline parallel stage.
|
|
485
694
|
|
|
486
695
|
Args:
|
|
487
|
-
net (
|
|
696
|
+
net (Cell): the inference network.
|
|
697
|
+
|
|
698
|
+
Raises:
|
|
699
|
+
TypeError: `net` is not in Cell type.
|
|
488
700
|
|
|
489
701
|
Supported Platforms:
|
|
490
702
|
``Ascend``
|
|
@@ -494,12 +706,13 @@ def sync_pipeline_shared_parameters(net):
|
|
|
494
706
|
Before running the following examples, you need to configure the communication environment variables.
|
|
495
707
|
|
|
496
708
|
For the Ascend device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
497
|
-
Startup <https://www.mindspore.cn/
|
|
709
|
+
Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .
|
|
498
710
|
|
|
499
711
|
>>> import numpy as np
|
|
500
712
|
>>> import mindspore as ms
|
|
501
713
|
>>> import mindspore.communication.management as D
|
|
502
714
|
>>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
|
|
715
|
+
>>> from mindspore.parallel.auto_parallel import AutoParallel
|
|
503
716
|
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
504
717
|
>>> class Embedding(nn.Cell):
|
|
505
718
|
... def __init__(self, shape):
|
|
@@ -547,14 +760,16 @@ def sync_pipeline_shared_parameters(net):
|
|
|
547
760
|
... ret = self.concat(ret)
|
|
548
761
|
... return ret
|
|
549
762
|
>>> D.init()
|
|
550
|
-
>>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
|
|
551
763
|
>>> net = Network()
|
|
552
764
|
>>> net = PipelineCellInference(net, 2)
|
|
553
765
|
>>> net.set_train(False)
|
|
554
766
|
>>> x = Tensor(np.ones((2, 4)), ms.float32)
|
|
555
767
|
>>> net.compile(x)
|
|
556
|
-
>>>
|
|
557
|
-
>>>
|
|
768
|
+
>>> pp_net = AutoParallel(net, parallel_mode="semi_auto")
|
|
769
|
+
>>> pp_net.full_batch = True
|
|
770
|
+
>>> pp_net.pipeline(stages=2, scheduler="1f1b")
|
|
771
|
+
>>> ms.parallel.sync_pipeline_shared_parameters(pp_net)
|
|
772
|
+
>>> print(pp_net.network.network.word_embedding.w.asnumpy())
|
|
558
773
|
[[1. 1. 1. 1.]
|
|
559
774
|
[1. 1. 1. 1.]
|
|
560
775
|
[1. 1. 1. 1.]
|
|
@@ -567,18 +782,25 @@ def sync_pipeline_shared_parameters(net):
|
|
|
567
782
|
"but got {}.".format(type(net)))
|
|
568
783
|
raise TypeError(msg)
|
|
569
784
|
|
|
570
|
-
|
|
785
|
+
parallel_net = _get_auto_parallel_net(net)
|
|
786
|
+
pipeline_stages = 1
|
|
787
|
+
if type(parallel_net).__name__ != 'AutoParallel':
|
|
788
|
+
pipeline_stages = _get_pipeline_stages()
|
|
789
|
+
else:
|
|
790
|
+
pipeline_stages = parallel_net._pipeline_stages
|
|
791
|
+
if pipeline_stages < 2:
|
|
571
792
|
return
|
|
572
793
|
|
|
573
794
|
layout_dict = net.parameter_layout_dict
|
|
574
|
-
if _is_in_auto_parallel_mode() and not layout_dict:
|
|
795
|
+
if (_is_in_auto_parallel_mode() or (type(parallel_net).__name__ == 'AutoParallel')) and not layout_dict:
|
|
575
796
|
from mindspore.common.api import _get_parameter_layout
|
|
576
797
|
layout_dict = _get_parameter_layout()
|
|
577
798
|
|
|
578
799
|
# switch to standalone mode
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
800
|
+
if type(parallel_net).__name__ != 'AutoParallel':
|
|
801
|
+
parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
|
|
802
|
+
full_batch = ms.context.get_auto_parallel_context("full_batch")
|
|
803
|
+
ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)
|
|
582
804
|
|
|
583
805
|
# synchronize shared parameter
|
|
584
806
|
for name, param in net.parameters_and_names():
|
|
@@ -586,7 +808,8 @@ def sync_pipeline_shared_parameters(net):
|
|
|
586
808
|
_sync_params(name, param, layout_dict[name])
|
|
587
809
|
|
|
588
810
|
# restore parallel context
|
|
589
|
-
|
|
811
|
+
if type(parallel_net).__name__ != 'AutoParallel':
|
|
812
|
+
ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
|
|
590
813
|
|
|
591
814
|
|
|
592
815
|
def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filter_prefix=None,
|
|
@@ -636,6 +859,9 @@ def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filte
|
|
|
636
859
|
ValueError: Checkpoint file's format is incorrect.
|
|
637
860
|
ValueError: Parameter's dict is None after load checkpoint file.
|
|
638
861
|
TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
|
|
862
|
+
|
|
863
|
+
Supported Platforms:
|
|
864
|
+
``Ascend``
|
|
639
865
|
"""
|
|
640
866
|
if not isinstance(ckpt_file_dir, str):
|
|
641
867
|
raise TypeError("The ckpt_file_dir should be a str.")
|
|
@@ -656,8 +882,10 @@ def set_op_strategy_config(mode="SAVE", path=""):
|
|
|
656
882
|
Set strategy json configuration when using sharding propagation.
|
|
657
883
|
|
|
658
884
|
.. warning::
|
|
659
|
-
This is an experimental interface, may be changed or canceled in the future
|
|
660
|
-
|
|
885
|
+
- This is an experimental interface, may be changed or canceled in the future, please use the api
|
|
886
|
+
:func:`mindspore.parallel.auto_parallel.AutoParallel.load_operator_strategy_file` or
|
|
887
|
+
:func:`mindspore.parallel.auto_parallel.AutoParallel.save_operator_strategy_file` instead;
|
|
888
|
+
- This interface currently doesn't support saving or loading strategies using layout.
|
|
661
889
|
|
|
662
890
|
Note:
|
|
663
891
|
- It only works when `parallel_mode=ParallelMode.AUTO_PARALLEL` and `search_mode='sharding_propagation'`.
|
|
@@ -692,3 +920,396 @@ def set_op_strategy_config(mode="SAVE", path=""):
|
|
|
692
920
|
AutoParallelContext.get_instance().set_ops_strategy_json_config(mode, path, "all")
|
|
693
921
|
else:
|
|
694
922
|
raise KeyError("Type must be 'SAVE' or 'LOAD'")
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
def build_searched_strategy(strategy_filename):
|
|
926
|
+
"""
|
|
927
|
+
Extract the sharding strategy for each parameter in the network from the strategy file
|
|
928
|
+
for distributed inference scenarios.
|
|
929
|
+
|
|
930
|
+
Args:
|
|
931
|
+
strategy_filename (str): Name of strategy file.
|
|
932
|
+
|
|
933
|
+
Returns:
|
|
934
|
+
Dict, whose key is parameter name and value is slice strategy of this parameter.
|
|
935
|
+
|
|
936
|
+
Raises:
|
|
937
|
+
ValueError: Strategy file is incorrect.
|
|
938
|
+
TypeError: `strategy_filename` is not a string.
|
|
939
|
+
|
|
940
|
+
Supported Platforms:
|
|
941
|
+
``Ascend``
|
|
942
|
+
|
|
943
|
+
Examples:
|
|
944
|
+
>>> from mindspore.parallel import build_searched_strategy
|
|
945
|
+
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
|
|
946
|
+
"""
|
|
947
|
+
return _build_searched_strategy(strategy_filename)
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
# disable pylint too broad Exception
|
|
951
|
+
# pylint: disable=W0212
|
|
952
|
+
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
|
|
953
|
+
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
|
|
954
|
+
format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
|
|
955
|
+
output_format='safetensors', name_map=None, max_process_num=64,
|
|
956
|
+
return_param_dict=False):
|
|
957
|
+
"""
|
|
958
|
+
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
959
|
+
|
|
960
|
+
Note:
|
|
961
|
+
`output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.
|
|
962
|
+
|
|
963
|
+
Args:
|
|
964
|
+
network (Cell): Network for distributed predication, When the format is `safetensors`, the network parameter
|
|
965
|
+
can be left blank or passed as None, and the interface will execute save mode.
|
|
966
|
+
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
|
|
967
|
+
predict_strategy (Union[dict, str]): Strategy of predication process. It means that using one device to predict
|
|
968
|
+
when setting predict_strategy as None. Default: ``None`` .
|
|
969
|
+
train_strategy_filename (str): The filename of training strategy protocol buffer file.
|
|
970
|
+
When train_strategy_filename is None, the training strategy file will be
|
|
971
|
+
obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
|
|
972
|
+
Therefore, the training strategy file needs to be specified
|
|
973
|
+
in at least one of them. Default: ``None`` .
|
|
974
|
+
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
975
|
+
into net when parameter name's suffix in checkpoint file is the same as the
|
|
976
|
+
parameter in the network. When the types are inconsistent, perform type conversion
|
|
977
|
+
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
978
|
+
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
|
|
979
|
+
is not required. Default: ``None`` .
|
|
980
|
+
dec_mode (str): Specifies the decryption
|
|
981
|
+
mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
|
|
982
|
+
This parameter is valid only when dec_key is not set to ``None`` .
|
|
983
|
+
Default: ``'AES-GCM'`` .
|
|
984
|
+
format (str): Input weight format to be loaded into the network.
|
|
985
|
+
It can be set to either "ckpt" or "safetensors". Default: "ckpt".
|
|
986
|
+
unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
|
|
987
|
+
Default: ``None`` .
|
|
988
|
+
dst_safetensors_dir (str): In the save mode scenario, the save directory for weights.
|
|
989
|
+
rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
|
|
990
|
+
globally by initializing the network; In save mode, save the file according to the input
|
|
991
|
+
sequence number. If it is not input, save the entire file.
|
|
992
|
+
output_format (str, optional): Control the format of the output checkpoint after conversion.
|
|
993
|
+
It can be set to either "ckpt" or "safetensors". Default: "safetensors".
|
|
994
|
+
name_map (dict): The weight mapping dictionary will modify the weight names according to the mapping
|
|
995
|
+
dictionary before loading or saving the segmented weights into the network. Default: None.
|
|
996
|
+
max_process_num (int): Maximum number of processes. Default: 64.
|
|
997
|
+
return_param_dict (bool): Whether to return the param_dict. Default: ``False``.
|
|
998
|
+
|
|
999
|
+
    Raises:
        TypeError: The types of the inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For Ascend devices, users need to prepare the rank table and set rank_id and device_id.
            Please see the `rank table startup
            <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
            for more details.

            For the CPU device, users need to write a dynamic cluster startup script. Please see the `Dynamic Cluster
            Startup <https://www.mindspore.cn/tutorials/en/master/parallel/dynamic_cluster.html>`_ .

        >>> import os
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, train
        >>> from mindspore.communication import init
        >>> from mindspore.parallel import load_distributed_checkpoint
        >>> from mindspore.parallel.auto_parallel import AutoParallel
        >>> from mindspore.nn.utils import no_init_parameters
        >>> from mindspore.common.initializer import initializer, One
        >>> from mindspore.communication.management import get_group_size
        >>>
        >>> step_per_epoch = 4
        >>> device_num = get_group_size()
        >>>
        >>> # Define the network structure.
        >>> class Net(nn.Cell):
        ...     def __init__(self, matmul_size, strategy=None):
        ...         super().__init__()
        ...         self.matmul_weight = ms.Parameter(initializer(One(), matmul_size, ms.float32))
        ...         self.matmul = ops.MatMul()
        ...         self.neg = ops.Neg()
        ...         if strategy is not None:
        ...             self.matmul.shard(strategy)
        ...
        ...     def construct(self, inputs):
        ...         x = self.matmul(inputs, self.matmul_weight)
        ...         x = self.neg(x)
        ...         return x
        >>>
        >>> # Create dataset.
        >>> def get_dataset(*inputs):
        ...     def generate():
        ...         for _ in range(step_per_epoch):
        ...             yield inputs
        ...     return generate
        >>>
        >>> # Train network and save distributed checkpoint.
        >>> def train_net():
        ...     ms.set_context(mode=ms.GRAPH_MODE)
        ...     init()
        ...     np.random.seed(1)
        ...     input_data = np.random.rand(16, 96).astype(np.float32)
        ...     label_data = np.random.rand(16, 16).astype(np.float32)
        ...     fake_dataset = get_dataset(input_data, label_data)
        ...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
        ...
        ...     # Set parallel strategy.
        ...     strategy = ((1, 4), (4, 1))
        ...     with no_init_parameters():
        ...         network = Net(matmul_size=(96, 16), strategy=strategy)
        ...         net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
        ...
        ...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        ...     network = AutoParallel(network, parallel_mode="semi_auto")
        ...     network.save_param_strategy_file(file_path="./train_strategy.ckpt")
        ...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
        ...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=True)
        ...     global_rank_id = int(os.getenv("RANK_ID"))
        ...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
        ...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
        ...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
        >>>
        >>> # Load distributed checkpoint and test.
        >>> def load_model():
        ...     ms.set_context(mode=ms.GRAPH_MODE)
        ...     init()
        ...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
        ...     with no_init_parameters():
        ...         network = Net(matmul_size=(96, 16))
        ...     network = AutoParallel(network, parallel_mode="semi_auto")
        ...     network.dataset_strategy(config="full_batch")
        ...     train_strategy_file = "./train_strategy.ckpt"
        ...     network.save_param_strategy_file(file_path=train_strategy_file)
        ...     model = ms.Model(network)
        ...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
        ...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
        ...     load_distributed_checkpoint(network, ckpt_file_list, predict_layout, None)
        ...     predict_result = model.predict(predict_data)
        ...     print(predict_result)
        >>>
        >>> train_net()
        >>> load_model()
        [[-9.62929535e+00, -9.76258755e+00, -9.70192051e+00 ... -9.67151260e+00, -9.71998310e+00, -9.64571190e+00],
        [-4.63218540e-01, -4.07317460e-01, -3.78161550e-01 ... -3.95918339e-01, -2.87363172e-01, -3.48693460e-01],
        ...
        [-4.28075647e+00, -4.36630344e+00, -4.25664043e+00 ... -4.32012939e+00, -4.30337954e+00, -4.27571440e+00]]
    """
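    # Both 'format' and 'output_format' must be one of the two supported serialization
    # formats before dispatching to the safetensors or ckpt loading path below.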
    if format not in ['safetensors', 'ckpt'] or output_format not in ['safetensors', 'ckpt']:
        raise ValueError(
            f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
            f"must be 'ckpt' or 'safetensors', but got {format}.")

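    # safetensors path: the ckpt-only arguments (checkpoint_filenames, train_strategy_filename,
    # dec_key, strict_load, dec_mode) must keep their default values; slicing and loading are
    # delegated to _load_parallel_checkpoint.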
    if format == 'safetensors':
        if unified_safetensors_dir is None:
            raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
                             f"when format is 'safetensors'.")
        unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
        for param in unsupport_param:
            if param is not None:
                raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
                                 f"when format is 'safetensors'.")
        if strict_load or dec_mode != 'AES-GCM':
            raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
                             f"when format is 'safetensors'.")
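        # A network was passed in: load only the slice that belongs to the local rank
        # (falling back to rank 0 when the communication framework is not initialized).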
        if network is not None:
            try:
                rank_id = get_rank()
            except RuntimeError:
                rank_id = 0
                logger.warning("Get rank failed, default loading weight for rank 0.")
            param_dict = _load_parallel_checkpoint(
                (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
                 return_param_dict))
            return param_dict
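        # Save mode: no network was given, so the converted slices are written to
        # dst_safetensors_dir, either for the single requested rank_id or, via a process
        # pool, for every rank of the destination strategy.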
        if dst_safetensors_dir is None:
            raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
                             f"when network is None.")
        if rank_id is not None:
            _load_parallel_checkpoint(
                (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
                 rank_id, output_format, name_map, return_param_dict))
        else:
            dst_strategy_dict = _build_searched_strategy(predict_strategy)
            dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
            dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
            dst_device_num = dst_stage_device_num * dst_stage_num
            tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
                                           dst_device_num, output_format, name_map, return_param_dict)
            with Pool(processes=max_process_num) as pool:
                list(pool.imap(_load_parallel_checkpoint, tasks))
        return True

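    # ckpt path: validate the inputs, resolve the training strategy, then merge the per-rank
    # checkpoint slices and re-split them according to the prediction strategy.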
    network = Validator.check_isinstance("network", network, ms.nn.Cell)
    _check_checkpoint_file(checkpoint_filenames)
    _check_predict_strategy(predict_strategy)

    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)

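    # Resolve the training strategy file: prefer the path recorded on an AutoParallel wrapper,
    # otherwise fall back to the auto-parallel context's 'strategy_ckpt_load_file'.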
    if train_strategy_filename is None:
        parallel_net = _get_auto_parallel_net(network)
        if parallel_net.__class__.__name__ == "AutoParallel":
            train_strategy_filename = parallel_net._save_strategy_file_path
        else:
            train_strategy_filename = ms.context.get_auto_parallel_context("strategy_ckpt_load_file")

    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

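    # The training device count is the product of the device arrangement recorded in the
    # strategy; it must match the number of checkpoint files, one file per training rank.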
    train_dev_count = 1
    ckpt_file_len = len(checkpoint_filenames)
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != ckpt_file_len:
        raise ValueError(f"For 'load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
                         f"equal to the device count of training process. "
                         f"But got the length of 'checkpoint_filenames'"
                         f" is {ckpt_file_len} and the device count is {train_dev_count}.")
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

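    # Index every parameter slice by the checkpoint file it came from (one file per training rank).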
    param_total_dict = defaultdict(dict)
    for file_index, file_name in enumerate(checkpoint_filenames):
        ckpt_dict = ms.load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
        for param_name, param in ckpt_dict.items():
            param_total_dict[param_name][file_index] = param

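    # For every parameter of the network, collect the slices that contribute to it, merge them
    # according to the training layout and split the result according to the prediction layout.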
    param_dict = {}
    param_not_in_strategy = []
    param_not_in_ckpt = []
    for _, param in network.parameters_and_names():
        sliced_params = []
        if param.name not in rank_list.keys():
            param_not_in_strategy.append(param.name)
            continue
        if param.name not in param_total_dict:
            param_not_in_ckpt.append(param.name)
            continue

        param_rank = rank_list.get(param.name)[0]
        skip_merge_split = rank_list.get(param.name)[1]
        shard_stride = train_strategy.get(param.name)[4]
        tensor_map = train_strategy.get(param.name)[1]
        first_dim_shard_idx = tensor_map[0] if tensor_map else -1
        device_arrangement = train_strategy.get(param.name)[0]
        first_dim_shard_size = 1
        if first_dim_shard_idx >= 0:
            first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
        if train_strategy.get(param.name)[5]:
            repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
        else:
            repeat_size = 0
        for rank in param_rank:
            param_total_list = list(range(0, ckpt_file_len))
            if first_dim_shard_size != 1:
                param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
            if repeat_size > 0:
                shard_size = shard_stride * train_strategy.get(param.name)[5]
                rank_index = param_total_list.index(rank)
                start = rank_index // shard_size * shard_size
                param_total_list = param_total_list[start:start + shard_size]
            if shard_stride > 0:
                param_stride = []
                # merge pre parameter
                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
                param_index = list(set(param_index))
                param_index.sort()
                for rank_num in param_index:
                    if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
                        from mindspore.ops import Cast
                        cpu_cast = Cast().set_device("CPU")
                        param_stride.append(
                            cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
                    else:
                        param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
            else:
                sliced_param = param_total_dict[param.name][rank]

            sliced_params.append(sliced_param)
        if skip_merge_split:
            split_param = sliced_params[0]
        else:
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
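        # If the prediction strategy applies optimizer sharding to this parameter, keep only
        # the slice of the merged tensor that belongs to the local rank in the sharding group.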
        opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
        if opt_shard_group:
            if split_param.data.dtype == mstype.bfloat16:
                from mindspore.ops import Cast
                cpu_cast = Cast().set_device("CPU")
                data = cpu_cast(split_param.data, mstype.float32).asnumpy()
            else:
                data = split_param.data.asnumpy()
            rank = get_rank(opt_shard_group)
            size = get_group_size(opt_shard_group)
            try:
                data_slice = np.split(data, size)[rank]
            except BaseException as e:
                logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
                                " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
                raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
                                                 f" in load distributed checkpoint for {param.name}. Data shape is "
                                                 f"{split_param.data.shape} and group is {opt_shard_group}.") from e
            split_param = Parameter(Tensor(data_slice), param.name,
                                    split_param.requires_grad, split_param.layerwise_parallel)
        param_dict[param.name] = split_param

    if param_not_in_strategy:
        logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
                       "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
                       .format(param_not_in_strategy))
    if param_not_in_ckpt:
        logger.warning("For 'load_distributed_checkpoint', {} parameters are in network and slice strategy but not in "
                       "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
                       .format(param_not_in_ckpt))

    ms.load_param_into_net(network, param_dict, strict_load=strict_load)
    return True


def restore_group_info_list(group_info_file_name):
    """
    Extract rank list information from communication domain files. To save the group info file,
    please export the GROUP_INFO_FILE environment variable,
    for example, "export GROUP_INFO_FILE=/data/group_info.pb".

    Args:
        group_info_file_name (str): Name of the group information file.

    Returns:
        List, the rank list.

    Raises:
        ValueError: The group information file is incorrect.
        TypeError: `group_info_file_name` is not str.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore.parallel import restore_group_info_list
        >>> restore_list = restore_group_info_list("./group_info.pb")
    """
    if not isinstance(group_info_file_name, str):
        raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
                        f"but got {type(group_info_file_name)}.")

    if not os.path.isfile(group_info_file_name):
        raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")

    if os.path.getsize(group_info_file_name) == 0:
        raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")

    return _restore_group_info_list(group_info_file_name)