mindspore-2.5.0-cp311-cp311-win_amd64.whl → mindspore-2.6.0-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -18,17 +18,20 @@ from __future__ import division
 
 import numpy as np
 
+import mindspore.log as logger
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
-from mindspore.communication import GlobalComm
+from mindspore.communication import GlobalComm, get_rank
 from mindspore.common import jit
-from mindspore.communication import create_group, destroy_group
+from mindspore.communication import create_group, destroy_group, get_group_size
 from mindspore.communication._comm_helper import _get_group_map
 from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
+from mindspore.parallel.shard import Layout
 
 _ALLGATHER_CELL = None
+ALLREDUCE_GROUP_LIST = []
 
 
 class AllGatherCell(Cell):
@@ -134,7 +137,7 @@ def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
 
 def _get_group_name(group_map, group):
     """get group name"""
-    group_name = str(group)
+    group_name = "remove_redundancy" + str(group)
     is_manual_communication_group = True
     if group_map:
         for name, rank_list in group_map.items():
@@ -142,20 +145,37 @@ def _get_group_name(group_map, group):
                 group_name = name
                 is_manual_communication_group = False
                 break
-    if is_manual_communication_group:
-        create_group(str(group), list(group))
     return group_name, is_manual_communication_group
 
 
-def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
+def _get_param_redundancy_reversed(param_redundancy, cur_rank):
+    """Generate the reverse mapping of parameter redundancy based on the current rank."""
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    return param_redundancy_reversed
+
+
+def _remove_param_not_load(param_name, param_not_load):
+    """Remove param_name from param_not_load."""
+    if param_not_load is not None and param_name in param_not_load:
+        param_not_load.remove(param_name)
+
+
+def _single_parameter_broadcast(net, layout, param_not_load=None):
     """
     Broadcast single parameter to other rank in data parallel dimension.
     """
     from mindspore import Tensor
     origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
     origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    cur_rank = get_rank()
     if layout:
-        param_redundancy = get_parameter_redundancy(layout
+        param_redundancy = get_parameter_redundancy(layout)
     else:
         param_redundancy = get_parameter_redundancy(net)
     if not param_redundancy:
@@ -163,33 +183,130 @@ def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
     single_params = remove_param_redundancy(param_redundancy)
     if not single_params:
         return
-    param_redundancy_reversed = {}
-    for key, redundancy in param_redundancy.items():
-        for item in redundancy:
-            if len(item) == 1:
-                continue
-            if cur_rank in item:
-                param_redundancy_reversed.setdefault(item, []).append(key)
+    param_redundancy_reversed = _get_param_redundancy_reversed(param_redundancy, cur_rank)
     if not param_redundancy_reversed or cur_rank not in single_params:
         return
     net_param_dict = net.parameters_dict()
     _chang_parallel_context(origin_dataset_strategy)
     group_map = _get_group_map()
+    if group_map:
+        group_map = {key: group_map[key] for key in sorted(group_map.keys())}
     for group, params in param_redundancy_reversed.items():
         group_name, is_manual_communication_group = _get_group_name(group_map, group)
         allreduce_input = []
         for param in params:
             if param not in net_param_dict:
                 continue
+            if param.startswith("accu_grads") or param.endswith("expert_load"):
+                continue
             real_param = net_param_dict[param]
+            _remove_param_not_load(real_param.name, param_not_load)
             if param not in single_params[cur_rank]:
                 real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
             allreduce_input.append(real_param)
         if not allreduce_input:
             continue
+        if is_manual_communication_group:
+            create_group(group_name, list(group))
+        allreduce_input.sort(key=lambda param: (str(param.shape), str(param.dtype)))
         communicator = SingleCommunicator(group_name)
         for real_param in allreduce_input:
-            real_param.set_data(communicator(real_param), real_param.sliced)
+            real_param.set_data(communicator(Tensor(real_param)), real_param.sliced)
         if is_manual_communication_group:
             destroy_group(group_name)
     _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
+
+
+def _insert_virtual_pp_dim(layout):
+    """insert virtual pp dim in device matrix and create new layout"""
+    if len(layout.to_dict()["rank_list"]) == get_group_size():
+        return layout
+    remain_pp = get_group_size() // len(layout.to_dict()["rank_list"])
+    layout_info = layout.to_dict()
+    device_matrix = layout_info["device_matrix"]
+    tensor_map = layout_info["tensor_map"]
+    alias_name = layout_info["alias_name"]
+    new_devmat = Layout((remain_pp,) + device_matrix, ("remain_pp",) + alias_name)
+    tensor_map_alias_name = []
+    for val in tensor_map:
+        sub_alias_name = []
+        if isinstance(val, tuple):
+            for sub_val in val:
+                if sub_val == -1:
+                    sub_alias_name.append("None")
+                else:
+                    sub_alias_name.append(alias_name[len(device_matrix) - sub_val - 1])
+            tensor_map_alias_name.append(tuple(sub_alias_name))
+        else:
+            if val == -1:
+                tensor_map_alias_name.append("None")
+            else:
+                tensor_map_alias_name.append(alias_name[len(device_matrix) - val - 1])
+    new_layout = new_devmat(*tensor_map_alias_name)
+    return new_layout
+
+
+class CommTensorDataForPP(Cell):
+    """Communicate tensor data for pipeline parallel scenario."""
+
+    def __init__(self, src_dtensor_info, dst_dtensor_info):
+        super().__init__()
+        self.zeros = P.Zeros()
+
+        self._current_rank_id = get_rank()
+        self._from_dev_num_in_stage = len(src_dtensor_info.layout.to_dict()["rank_list"])
+        self._from_rank_id = src_dtensor_info.layout.to_dict()["rank_list"]
+        self._current_rank_has_data = self._current_rank_id in src_dtensor_info.layout.to_dict()["rank_list"]
+        self._diff_rank_id = [
+            rank_id for rank_id in dst_dtensor_info.layout.to_dict()["rank_list"] if rank_id not in self._from_rank_id]
+        self._group, self._root_idx = self._create_all_reduce_group()
+
+    def comm_data(self, comm_data):
+        """communicate data"""
+        from mindspore import mint
+        comm_handle = mint.distributed.broadcast(comm_data, self._root_idx, self._group, async_op=False)
+        return comm_handle
+
+    def _create_all_reduce_group(self):
+        """create all reduce group"""
+        global ALLREDUCE_GROUP_LIST
+        current_rank_stage_id = self._current_rank_id // self._from_dev_num_in_stage
+        end_stage = self._from_dev_num_in_stage * (current_rank_stage_id + 1)
+        rank_pos_in_stage = [rank_id for rank_id in range(self._from_dev_num_in_stage * current_rank_stage_id,
+                                                          end_stage)].index(self._current_rank_id)
+        root_idx = self._from_rank_id[rank_pos_in_stage]
+        all_reduce_rank_list = [self._from_rank_id[rank_pos_in_stage]]
+        while rank_pos_in_stage < len(self._diff_rank_id):
+            all_reduce_rank_list.append(self._diff_rank_id[rank_pos_in_stage])
+            rank_pos_in_stage += self._from_dev_num_in_stage
+        all_reduce_rank_list.sort()
+        str_rank_list = '-'.join([str(rank) for rank in all_reduce_rank_list])
+        all_reduce_group = f"pp_allreduce_group-{str_rank_list}"
+        if all_reduce_group in ALLREDUCE_GROUP_LIST:
+            return all_reduce_group, root_idx
+        ALLREDUCE_GROUP_LIST.append(all_reduce_group)
+        create_group(all_reduce_group, all_reduce_rank_list)
+        logger.debug(f"Create group {all_reduce_group} for tensor data communication.")
+        return all_reduce_group, root_idx
+
+
+class RedistributionCell(Cell):
+    """Redistribute src_layout to dst_layout"""
+
+    def __init__(self, src_layout, dst_layout):
+        super().__init__()
+        if src_layout is None or dst_layout is None:
+            raise ValueError("src_layout and dst_layout should not be None.")
+        self._total_dev_num = get_group_size()
+        src_layout = _insert_virtual_pp_dim(src_layout)
+        dst_layout = _insert_virtual_pp_dim(dst_layout)
+        self.src_identity = P.Identity().shard(in_strategy=(src_layout,), out_strategy=(src_layout,))
+        self.src_identity.add_prim_attr("self_define_shard", True)
+        self.dst_identity = P.Identity().shard(in_strategy=(dst_layout,), out_strategy=(dst_layout,))
+        self.dst_identity.add_prim_attr("self_define_shard", True)
+
+    def construct(self, input_tensor):
+        """run redistribution"""
+        src_tensor = self.src_identity(input_tensor)
+        dst_tensor = self.dst_identity(src_tensor)
+        return dst_tensor
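The `_get_param_redundancy_reversed` helper added in the hunk above inverts the redundancy map so that parameters are grouped by the rank tuples that hold redundant copies and contain the current rank; `_single_parameter_broadcast` then creates one communication group per tuple and broadcasts all of that tuple's parameters together. A minimal standalone sketch of the same logic (the parameter names and rank tuples below are made-up illustrations, not values from MindSpore):

```python
# Standalone sketch of the redundancy-reversal logic shown in the diff above.
# Input data is hypothetical; only the algorithm mirrors the added helper.

def get_param_redundancy_reversed(param_redundancy, cur_rank):
    """Group parameter names by the redundancy rank-tuples containing cur_rank."""
    reversed_map = {}
    for key, redundancy in param_redundancy.items():
        for item in redundancy:
            if len(item) == 1:      # a single-rank entry carries no redundancy
                continue
            if cur_rank in item:
                reversed_map.setdefault(item, []).append(key)
    return reversed_map

param_redundancy = {
    "dense.weight": [(0, 2), (1, 3)],   # replicated on ranks {0, 2} and {1, 3}
    "dense.bias": [(0, 1, 2, 3)],       # replicated on all four ranks
}
print(get_param_redundancy_reversed(param_redundancy, cur_rank=0))
# {(0, 2): ['dense.weight'], (0, 1, 2, 3): ['dense.bias']}
```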
@@ -19,6 +19,7 @@ import os
 import json
 import numpy as np
 import mindspore as ms
+from mindspore import _checkparam as Validator
 from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
     _get_needed_rank_list_by_layouts, _get_needed_rank_transform_operator_map_by_layouts, \
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
@@ -34,7 +35,12 @@ def _convert_to_list(strategy, rank_id=None):
         try:
             layout = strategy.get(param_name)
             dev_mat = list(layout.dev_matrix[0].dim)
-            tensor_map = list(layout.tensor_map[0].dim)
+            # for layout one axis two slices, layout(("dp", "mp"), "None")
+            if len(layout.tensor_map) > 1:
+                tensor_map = [list(tensor_map.dim) for tensor_map in layout.tensor_map
+                              if list(tensor_map.dim)]
+            else:
+                tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
             field_size = int(layout.field)
             shard_stride = int(layout.opt_weight_shard_step)
@@ -417,7 +423,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         from_opt_shard_size = 0
         if src_strategy_list is not None:
             if param_name not in src_strategy_list:
-                ms.log.
+                ms.log.info("The parameter {} is not in src_strategy.".format(param_name))
                 continue
            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
                 src_strategy_list.get(param_name))
@@ -427,7 +433,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         to_opt_shard_size = 0
         if dst_strategy_list is not None:
             if param_name not in dst_strategy_list:
-                ms.log.
+                ms.log.info("The parameter {} is not in dst_strategy.".format(param_name))
                 continue
            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                 dst_strategy_list.get(param_name))
@@ -441,6 +447,9 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
                 continue
             origin_tensor_shape += (item * param_strategy[i],)
 
+        has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+        has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
         from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
             from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
         to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -460,6 +469,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
         to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
         _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+        _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
         transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
         param_total_dict_copy = param_total_dict[param_name].copy()
         _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
@@ -556,6 +566,32 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
             param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
 
 
+def _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+                                  insert_from_reshape, insert_to_reshape):
+    """ insert layout expand op reshape """
+    from_opt_shard_size = from_info_tuple[0]
+    from_dev_matrix = from_info_tuple[1]
+    from_tensor_map = from_info_tuple[2]
+    from_full_tensor_shape = from_info_tuple[3]
+    to_opt_shard_size = to_info_tuple[0]
+    to_dev_matrix_origin = to_info_tuple[1]
+    to_tensor_map_origin = to_info_tuple[2]
+    origin_tensor_shape = to_info_tuple[3]
+    for param_rank, _ in param_rank_map.items():
+        if from_opt_shard_size == 0 and insert_from_reshape:
+            from_slice_tensor_shape = ()
+            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+            for i, item in enumerate(from_full_tensor_shape):
+                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
+            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
+        if to_opt_shard_size == 0 and insert_to_reshape:
+            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
+            to_slice_tensor_shape = ()
+            for i, item in enumerate(origin_tensor_shape):
+                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
+            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+
 def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
     """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
     total_device_num = 1
@@ -569,4 +605,59 @@ def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded
     start = rank - offset
     param_total_list = list(range(start, start + range_size))
     return param_total_list
-
+
+
+def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
+                           output_format, name_map, return_param_dict):
+    """gather transform tasks"""
+    tasks = []
+    for rank in range(0, dst_device_num):
+        tasks.append(
+            (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
+             return_param_dict))
+    return tasks
+
+
+def _check_checkpoint_file(checkpoint_filenames):
+    """Check checkpoint file name."""
+    for index, filename in enumerate(checkpoint_filenames):
+        if not isinstance(filename, str) or not os.path.exists(filename) \
+                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
+            raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
+                             f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
+                             f"be a string ending with '.ckpt', and the checkpoint file it represents must "
+                             f"be exist and not empty.")
+
+
+def _check_predict_strategy(predict_strategy):
+    """Check predict strategy."""
+
+    def _check_int_list(arg):
+        if not isinstance(arg, list):
+            return False
+        for item in arg:
+            if not isinstance(item, int):
+                return False
+        return True
+
+    if predict_strategy is None:
+        return
+
+    flag = True
+    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
+    for key in predict_strategy.keys():
+        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
+                or len(predict_strategy[key]) < 4:
+            flag = False
+        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
+        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
+                not (_check_int_list(param_split_shape) or not param_split_shape) or \
+                not (isinstance(field_size, int) and field_size == 0):
+            flag = False
+
+    if not flag:
+        raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
+                         f"the key of it must be string, and the value of it must be list or tuple that "
+                         f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
+                         f"param_split_shape (list[int]) and field_size (int, which value is 0)."
+                         f"Please check whether 'predict_strategy' is correct.")
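The new `_insert_expand_layout_reshape` pass above prepends and appends `('Reshape', slice_shape)` operators, where each per-device slice shape is the full tensor shape divided element-wise by the shard counts returned by `_get_tensor_strategy`. A small sketch of that shape arithmetic (the shape and strategy values below are illustrative assumptions, not values from a real checkpoint):

```python
# Sketch of the slice-shape arithmetic used by the inserted 'Reshape' operators:
# every dimension of the full shape is divided by its shard count.

def slice_shape(full_shape, tensor_strategy):
    """Per-device slice shape = full shape divided element-wise by the strategy."""
    return [dim // shards for dim, shards in zip(full_shape, tensor_strategy)]

full_shape = [8, 8]       # hypothetical full parameter shape
tensor_strategy = [4, 1]  # dim 0 split across 4 devices, dim 1 not split
print(slice_shape(full_shape, tensor_strategy))  # [2, 8]
```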
@@ -115,7 +115,7 @@ def _set_ps_context(**kwargs):
         enable_ps (bool): Whether to enable parameter server training mode.
             Only after enable_ps is set True, the environment variables will be effective.
             Default: ``False``.
-        config_file_path (
+        config_file_path (str): Configuration file path used by recovery. Default: ''.
         scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
         enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False``.
         client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ''.
@@ -33,18 +33,23 @@ def recovery_context():
         RECOVERY_CONTEXT = RecoveryContext.get_instance()
     return RECOVERY_CONTEXT
 
+
 _set_recovery_context_func_map = {
     "ckpt_path": recovery_context().set_ckpt_path,
-    "need_reset": recovery_context().set_need_reset
+    "need_reset": recovery_context().set_need_reset,
+    "is_reboot_node": recovery_context().set_is_reboot_node,
+    "is_arf": recovery_context().set_is_arf
 }
 
 _get_recovery_context_func_map = {
     "enable_recovery": recovery_context().enable_recovery,
+    "enable_repeat_register": recovery_context().enable_repeat_register,
     "latest_ckpt_file": recovery_context().latest_ckpt_file,
     "latest_ckpt_epoch": recovery_context().latest_ckpt_epoch,
     "latest_ckpt_step": recovery_context().latest_ckpt_step,
     "need_reset": recovery_context().need_reset,
     "recovery_path": recovery_context().recovery_path,
+    "is_arf": recovery_context().is_arf,
     "ckpt_path": recovery_context().ckpt_path
 }
 
@@ -64,7 +69,7 @@ def _set_recovery_context(**kwargs):
         MS_RECOVERY_INTERVAL # The persistent interval for recovery
 
     Args:
-        ckpt_path (
+        ckpt_path (str): Set the recovery path used to save checkpoint. Default: ''.
         need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery.
             Default: ``False``.
 
mindspore/parallel/_tensor.py CHANGED
@@ -38,10 +38,17 @@ def _get_tensor_strategy(dev_mat, tensor_map):
     """
     tensor_strategy = []
     for dim in tensor_map:
-        if dim == -1:
-            tensor_strategy.append(1)
+        if isinstance(dim, (tuple, list)):
+            acc_stra = 1
+            for i in dim:
+                if i != -1:
+                    acc_stra *= dev_mat[len(dev_mat) - i - 1]
+            tensor_strategy.append(acc_stra)
         else:
-            tensor_strategy.append(dev_mat[-dim - 1])
+            if dim == -1:
+                tensor_strategy.append(1)
+            else:
+                tensor_strategy.append(dev_mat[-dim - 1])
     return tensor_strategy
 
 
@@ -182,7 +189,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     Args:
         dev_mat (list): The device matrix of devices.
         tensor_map (list): The split strategy of tensor.
-        opt_shard_group(
+        opt_shard_group(str): The group of optimizer shard
 
     Returns:
         Integer, the slice index for slice on this device.
@@ -388,6 +395,124 @@ def _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
     return from_tensor_layout, to_tensor_layout
 
 
+def _expand_layout(dev_matrix, tensor_map, tensor_shape):
+    """
+    expand nested tensor_map and reshape tensor shape according to tensor_map
+    dev_matrix = [4, 2, 2]
+    tensor_map = [[2, 1], 0]
+    tensor_shape = [8, 8]
+    =>
+    expanded_tensor_map = [2, 1, 0]
+    expanded_tensor_map = [4, 8/4, 8]
+    """
+    new_tensor_map = []
+    new_tensor_shape = []
+    for index, dim in enumerate(tensor_map):
+        if isinstance(dim, (tuple, list)):
+            accu_shape = 1
+            for i in range(len(dim) - 1):
+                new_tensor_map.append(dim[i])
+                new_tensor_shape.append(dev_matrix[len(dev_matrix) - 1 - dim[i]])
+                accu_shape *= dev_matrix[len(dev_matrix) - 1 - dim[i]]
+            new_tensor_map.append(dim[-1])
+            new_tensor_shape.append(tensor_shape[index] // accu_shape)
+        else:
+            new_tensor_map.append(dim)
+            new_tensor_shape.append(tensor_shape[index])
+    return dev_matrix, new_tensor_map, new_tensor_shape
+
+
+def _construct_tensor_layout_for_opt_shard_by_layout(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
+                                                     origin_full_tensor_shape):
+    """
+    Construct tensor layout for optimizer parallel when using layout.
+    For example, For Tensor with shape (4,2)
+    dev_matrix = [2, 2, 2, 2]
+    tensor_map = [[1, 0], -1]
+    opt_shard_size = 2
+    ==>
+    dev_matrix = [2, 2, 2, 2]
+    tensor_map = [[1, 0], 2, -1]
+    the new strategy is [4, 2, 1]
+    the tensor_shape should reshape to (model_parallel_size, -1, xx, xx)
+    first 4 means the model parallel sharding of data_dim
+    second 2 means the opt sharding of data_dim.
+    """
+    if opt_shard_step == 0 or opt_shard_size == 0:
+        return dev_matrix, tensor_map, list(origin_full_tensor_shape)
+    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
+    repeated_dim = []
+    dev_sharded_index = []
+    dev_matrix, expanded_tensor_map, _ = _expand_layout(dev_matrix, tensor_map, origin_full_tensor_shape)
+    for dim in expanded_tensor_map:
+        if dim != -1:
+            dev_sharded_index.append(len(dev_matrix) - dim - 1)
+    for index, value in enumerate(dev_matrix):
+        if index not in dev_sharded_index and value > 1:
+            repeated_dim.append(index)
+    if not repeated_dim:
+        raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
+                         format(dev_matrix, tensor_map))
+    return _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
+                                           tensor_strategy, repeated_dim)
+
+
+def _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
+                                    tensor_strategy, repeated_dim):
+    """
+    helper function to assign repeated device_matrix dim for opt shard.
+    """
+    new_dev_matrix = list(copy.deepcopy(dev_matrix))
+    new_dev_matrix_map = list(range(len(dev_matrix)))
+    opt_shard_dim = []
+    remained_opt_shard_size = opt_shard_size if opt_shard_size != -1 else \
+        int(np.prod([dev_matrix[i] for i in repeated_dim]))
+    for dim in repeated_dim[::-1]:
+        opt_sharding_size = dev_matrix[dim]
+        if remained_opt_shard_size // opt_sharding_size == 0:
+            if opt_sharding_size % remained_opt_shard_size != 0:
+                raise ValueError("dev_matrix value {} at dim {} cannot be divided by needed opt sharding "
+                                 "size {}".format(dev_matrix[dim], len(dev_matrix) - dim - 1,
+                                                  remained_opt_shard_size))
+            opt_sharding_size = remained_opt_shard_size
+            # update dev_matrix
+            new_dev_matrix[dim] = dev_matrix[dim] // opt_sharding_size
+            new_dev_matrix.insert(dim + 1, opt_sharding_size)
+            for i in range(len(dev_matrix) - dim - 1, len(dev_matrix)):
+                new_dev_matrix_map[i] += 1
+        if remained_opt_shard_size % opt_sharding_size != 0:
+            raise ValueError("Remained opt_shard_size {} cannot be divided by current sharding size {}, "
+                             "the repeat dim is {} with dev_matrix value {}".
+                             format(remained_opt_shard_size, opt_sharding_size,
+                                    len(dev_matrix) - dim - 1, dev_matrix[dim]))
+        remained_opt_shard_size //= opt_sharding_size
+        opt_shard_dim.insert(0, dim)
+        if remained_opt_shard_size == 1:
+            break
+    tensor_map_new = list(copy.deepcopy(tensor_map))
+    if len(new_dev_matrix) != len(dev_matrix):
+        opt_shard_dim = list(map(lambda x: x + 1, opt_shard_dim))
+        for index, item in enumerate(tensor_map_new):
+            if isinstance(item, (tuple, list)):
+                item = list(map(lambda x: new_dev_matrix_map[x] if x >= 0 else x, item))
+                tensor_map_new[index] = item
+            else:
+                if item >= 0:
+                    tensor_map_new[index] = new_dev_matrix_map[item]
+    tensor_shape_new = list(copy.deepcopy(origin_full_tensor_shape))
+    tensor_shape_new[0] = tensor_strategy[0]
+    first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
+    accu_shape = 1
+    for i in range(len(opt_shard_dim) - 1):
+        opt_sharding_size = new_dev_matrix[opt_shard_dim[i]]
+        tensor_shape_new.insert(i + 1, opt_sharding_size)
+        accu_shape = accu_shape * opt_sharding_size
+    tensor_shape_new.insert(len(opt_shard_dim), first_dim_no_sharding_size // accu_shape)
+    for index, r_dim in enumerate(opt_shard_dim):
+        tensor_map_new.insert(index + 1, len(new_dev_matrix) - r_dim - 1)
+    return list(new_dev_matrix), tensor_map_new, tensor_shape_new
+
+
 def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
                                            origin_full_tensor_shape):
     """
@@ -404,6 +529,11 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
     And the model parallel sharding dim is the right of opt sharding dim, so it would be 0-1-2-3 model parallel sharding
     then 0-4 optimizer sharding.
     """
+    has_layout = any(isinstance(i, (list, tuple)) for i in tensor_map)
+    if has_layout:
+        output = _construct_tensor_layout_for_opt_shard_by_layout(dev_matrix, tensor_map, opt_shard_step,
+                                                                  opt_shard_size, origin_full_tensor_shape)
+        return _expand_layout(*output)
 
     if opt_shard_step == 0 or opt_shard_size == 0:
         return dev_matrix, tensor_map, list(origin_full_tensor_shape)
@@ -424,18 +554,8 @@
                          format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
     first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
     if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
-        tensor_shape_new = list(copy.deepcopy(origin_full_tensor_shape))
-        tensor_shape_new[0] = tensor_strategy[0]
-        accu_shp = 1
-        for i in range(len(repeated_dim) - 1):
-            opt_sharding_size = dev_matrix[repeated_dim[i]]
-            tensor_shape_new.insert(i + 1, opt_sharding_size)
-            accu_shp = accu_shp * opt_sharding_size
-        tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
-        tensor_map_new = list(copy.deepcopy(tensor_map))
-        for index, r_dim in enumerate(repeated_dim):
-            tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
-        return list(dev_matrix), tensor_map_new, tensor_shape_new
+        return _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
+                                               tensor_strategy, repeated_dim)
 
     full_tensor_shape = list(origin_full_tensor_shape)
     full_tensor_shape[0] = tensor_strategy[0]
@@ -610,9 +730,13 @@ def _apply_operator(operator_name):
     """
     if not isinstance(numpy_data_list, list):
         raise TypeError("The data_list should be a list.")
+    new_numpy_data_list = []
     for numpy_data in numpy_data_list:
-        if
-
+        if str(type(numpy_data)) == "<class 'builtins.PySafeSlice'>":
+            new_numpy_data_list.append(numpy_data[:])
+        else:
+            new_numpy_data_list.append(numpy_data)
+    numpy_data_list = new_numpy_data_list
     _check_operator(allgather_op)
     concat_group = allgather_op[1][:-1]
     if len(concat_group) != len(numpy_data_list):
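The nested `tensor_map` branch added to `_get_tensor_strategy` above multiplies the device-matrix sizes of every axis referenced by a nested entry. A standalone sketch using the values from the `_expand_layout` docstring example (`dev_matrix = [4, 2, 2]`, `tensor_map = [[2, 1], 0]`):

```python
# Standalone sketch of the nested tensor_map handling added to _get_tensor_strategy.

def get_tensor_strategy(dev_mat, tensor_map):
    """Shard count per tensor dimension for a (possibly nested) tensor_map."""
    tensor_strategy = []
    for dim in tensor_map:
        if isinstance(dim, (tuple, list)):
            acc = 1
            for i in dim:
                if i != -1:
                    acc *= dev_mat[len(dev_mat) - i - 1]
            tensor_strategy.append(acc)
        elif dim == -1:
            tensor_strategy.append(1)          # -1 means the dimension is not sharded
        else:
            tensor_strategy.append(dev_mat[-dim - 1])
    return tensor_strategy

print(get_tensor_strategy([4, 2, 2], [[2, 1], 0]))  # [8, 2]
```

So a nested entry such as `[2, 1]` shards one tensor dimension across the product of the referenced device-matrix axes (4 * 2 = 8 here), while the plain entry `0` keeps the single-axis behavior.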