mindspore 2.4.10-cp39-cp39-win_amd64.whl → 2.6.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/train/_utils.py
CHANGED

@@ -16,22 +16,23 @@
 from __future__ import absolute_import
 
 import os
-import
-from datetime import datetime
+import sys
 import json
 from collections.abc import Iterable
 
+import time
 import numpy as np
 
 from mindspore.common.tensor import Tensor
-from mindspore._c_expression import
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore._c_expression import MSContext, ms_ctx_param
 from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
 from mindspore.common import dtype as mstype
 from mindspore import context
 from mindspore import log as logger
 from mindspore import _checkparam as Validator
 from mindspore.common.api import _cell_graph_executor
-from mindspore.communication import get_group_size
+from mindspore.communication.management import get_rank, get_group_size
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

@@ -65,6 +66,11 @@ def _get_types_and_shapes(dataset):
     return dataset_types, dataset_shapes
 
 
+def enable_data_broadcast():
+    """Get status to indicate if enable dataset broadcast."""
+    return MSContext.get_instance().get_param(ms_ctx_param.dataset_broadcast_opt_level) > 0
+
+
 def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
     """Initialize and execute the dataset graph."""
     batch_size = exec_dataset.get_batch_size()

@@ -77,15 +83,12 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
     if queue_name is None:
         queue_name = str("")
 
+    # Don't enable dynamic shape(multi-subgraph) feature in pp/data_broadcast mode,
+    # otherwise get_data_info will stuck since some rank do not consume data.
     use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+    data_broadcast = enable_data_broadcast()
 
-
-    dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
-    dynamic_sink1 = True
-    if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
-        dynamic_sink1 = False
-
-    if use_pipeline_parallel or not dynamic_sink1:
+    if use_pipeline_parallel or data_broadcast:
         create_data_info_queue = False
 
     exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,

@@ -320,9 +323,15 @@ def parse_strategy_ckpt(file_name):
 def _get_strategy_opt_shard(param_redundancy_dict, parameter_layout_opt_shard):
     """Strategy ckpt append opt shard."""
     for key, value in parameter_layout_opt_shard.items():
-        if value[1]
-            opt_para_num = value[1]
+        if value[1] != 0:
             param_redundancy_ranks = param_redundancy_dict.get(key)
+            if value[1] != -1:
+                opt_para_num = value[1]
+            elif param_redundancy_ranks:
+                opt_para_num = len(param_redundancy_ranks) * len(param_redundancy_ranks[0]) // value[0]
+            else:
+                raise ValueError(f"For get_parameter_redundancy, the format of the parallel communication domain for "
+                                 f"the optimizer is incorrect.")
             res = []
             for param_ranks in param_redundancy_ranks:
                 if len(param_ranks) % opt_para_num == 0:

@@ -374,20 +383,40 @@ def _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank):
         param_redundancy_dict[key] = tuple(redundancy_list)
 
 
-def get_parameter_redundancy(layout_obj, initial_rank=0):
+def _get_initial_rank(parameter_layout):
+    """Get the initial rank of pp."""
+    for k, _ in parameter_layout.items():
+        dev_matrix = parameter_layout[k][0]
+        break
+    dev_num = 1
+    if dev_matrix:
+        for i in dev_matrix:
+            dev_num *= i
+    rank_id = get_rank()
+    initial_rank = (rank_id // dev_num) * dev_num
+    return initial_rank
+
+
+def _get_pp_size_from_redundancy_map(param_redundancy):
+    """Get pp size from redundancy map."""
+    for _, v in param_redundancy.items():
+        return len(v) * len(v[0])
+
+
+def get_parameter_redundancy(layout_obj, initial_rank=None):
     """
     Get parameter redundancy map.
 
     Args:
         layout_obj (Union[str, layout): File name of `strategy.ckpt` or net.parameter_layout_dict.
-        initial_rank (int): Start rank id for each pipeline. Default:
+        initial_rank (int): Start rank id for each pipeline. Default: ``None``.
 
     Returns:
         Dict, dict of parameter redundancy info.
 
     Examples:
         >>> from mindspore.train.utils import get_parameter_redundancy
-        >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt")
+        >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt", initial_rank=0)
        {'param1': ((0, 1, 2, 3, 4, 5, 6, 7),),
         'param2': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)),
         'param3': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)),

@@ -404,7 +433,8 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
         from mindspore.communication.management import get_process_group_ranks
         groups_ranks = (tuple(get_process_group_ranks()),)
         param_redundancy_dict = {param.name: groups_ranks for _, param in layout_obj.parameters_and_names()}
-
+        sorted_param_redundancy_dict = {key: param_redundancy_dict[key] for key in sorted(param_redundancy_dict.keys())}
+        return sorted_param_redundancy_dict
     else:
         parameter_layout = {}
         for k, v in layout_obj.items():

@@ -412,6 +442,9 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
 
     param_redundancy_dict = {}
 
+    if initial_rank is None:
+        initial_rank = _get_initial_rank(parameter_layout)
+
     _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank)
 
     if isinstance(layout_obj, str):

@@ -419,7 +452,8 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
     else:
         _get_layout_opt_shard(layout_obj, param_redundancy_dict)
 
-
+    sorted_param_redundancy_dict = {key: param_redundancy_dict[key] for key in sorted(param_redundancy_dict.keys())}
+    return sorted_param_redundancy_dict
 
 
 def _collect_settings_by_rank(redundancy_map):

@@ -514,12 +548,47 @@ def parse_hccl_file(hccl_file_path):
     return rankid_dict
 
 
-def
-
-
-
-
-
-
-
-
+def _progress_bar(iterable, total=None):
+    """
+    Decorate an iterable object, returning an iterator which acts exactly
+    like the original iterable, but prints a dynamically updating
+    progressbar every time a value is requested.
+    """
+    if total is None:
+        total = len(iterable)
+
+    start_time = time.time()
+
+    def print_progress_bar(iteration):
+        percent = f"{100 * (iteration / float(total)):.1f}"
+        bar_length = 40
+        filled_length = int(bar_length * iteration // total)
+        bar = '█' * filled_length + '-' * (bar_length - filled_length)
+
+        elapsed_time = time.time() - start_time
+        estimated_total_time = elapsed_time / iteration * total
+        remaining_time = estimated_total_time - elapsed_time
+
+        elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
+        remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
+
+        sys.stdout.reconfigure(encoding="utf-8")
+        print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
+        if iteration == total:
+            print()
+
+    for i, item in enumerate(iterable, start=1):
+        yield item
+        print_progress_bar(i)
+
+
+def _load_and_transform(path, name_map, load_func, transform_func):
+    if load_func is not None:
+        param_dict = load_func(path)
+    else:
+        param_dict = path
+    transform_dict = {}
+    for k, v in param_dict.items():
+        new_name = name_map.get(k, k) if name_map is not None else k
+        transform_dict[new_name] = transform_func(v, new_name)
+    return transform_dict
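
For orientation, the new `_get_initial_rank` default boils down to rounding the current rank down to the first rank of its pipeline chunk. A minimal sketch of that arithmetic (the device matrix and rank below are hypothetical example values, not taken from the diff):

# Sketch of the 2.6.0 default-initial-rank computation; dev_matrix and
# rank_id are made-up example values.
dev_matrix = [2, 2, 4]              # device matrix from a parameter's layout

dev_num = 1
for dim in dev_matrix:              # chunk size = product of the dims: 16
    dev_num *= dim

rank_id = 19                        # pretend get_rank() returned 19
initial_rank = (rank_id // dev_num) * dev_num
print(initial_rank)                 # 16, the first rank of rank 19's chunk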
mindspore/train/amp.py
CHANGED

@@ -101,6 +101,7 @@ AMP_AUTO_BLACK_LIST = [
     P.LayerNorm,
     gen.LayerNormExt,
     P.BatchNorm,
+    gen.BatchNormExt,
     gen.GroupNorm,
     P.KLDivLoss,
     P.SmoothL1Loss,

@@ -112,6 +113,7 @@ AMP_AUTO_BLACK_LIST = [
     P.Pdist,
     P.Cdist,
     P.Renorm,
+    gen.MSELossExt,
 ]
 
 # Indicates which inputs of primitives need to be converted

@@ -428,15 +430,15 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
 
         ``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
         ``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
-        ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
+        ``BatchNormExt``, ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
         ``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
         ``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
-        ``Norm``
+        ``Norm``, ``MSELossExt``
 
     Operators in `promote_list` are:
 
         ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
-        ``BiasAdd``
+        ``BiasAdd``, ``AddN``, ``Concat``
 
     For details on automatic mixed precision, refer to
     `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .

@@ -636,7 +638,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
 
 
 def _is_grad_accumulation(mcell):
-    if mcell.cls_name == "GradAccumulationCell":
+    if mcell.cls_name == "GradAccumulationCell" or mcell.cls_name == "GradAccumulation":
         return True
     for cell in mcell.cells():
         if _is_grad_accumulation(cell):

@@ -837,12 +839,14 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
         - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
         - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
-          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level` or `level`
           need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
         - Primitives for blacklist is not support yet.
 
     Args:
         network (Cell): Definition of the network.
+
+    Keyword Args:
         white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: ``None`` , means
             white list is not used.
         black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
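
Taken together, the black-list additions (`BatchNormExt`, `MSELossExt`) and promote-list additions (`AddN`, `Concat`) only matter on the list-driven AMP path. A minimal usage sketch, assuming the list-driven ``auto`` level and a placeholder network:

# Illustrative only: nn.Dense stands in for a real model. With
# amp_level="auto", list-managed ops (now including BatchNormExt and
# MSELossExt on the black list) stay in float32 while eligible ops run
# in the requested low-precision dtype.
import mindspore as ms
from mindspore import nn
from mindspore.train.amp import auto_mixed_precision

net = nn.Dense(16, 16)
net = auto_mixed_precision(net, amp_level="auto", dtype=ms.float16)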
mindspore/train/callback/__init__.py
CHANGED

@@ -36,9 +36,9 @@ from mindspore.train.callback._reduce_lr_on_plateau import ReduceLROnPlateau
 from mindspore.train.callback._on_request_exit import OnRequestExit
 from mindspore.train.callback._backup_and_restore import BackupAndRestore
 from mindspore.train.callback._flops_collector import FlopsUtilizationCollector
-from mindspore.train.callback.
+from mindspore.train.callback._train_fault_tolerance import TrainFaultTolerance
 
 __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "FlopsUtilizationCollector",
            "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
            "History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping", "OnRequestExit", "BackupAndRestore",
-           "
+           "TrainFaultTolerance"]
mindspore/train/callback/_callback.py
CHANGED

@@ -121,10 +121,7 @@
     When creating a custom Callback, model context information can be obtained in Callback
     methods by calling `RunContext.original_args()`, which is a dictionary varivable
     recording current attributes. Users can add custimized attributes to the information.
-    Training process can also be stopped by calling `request_stop` method.
-    of custom Callback, please check
-    `Callback tutorial <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/
-    callback.html#customized-callback-mechanism>`_.
+    Training process can also be stopped by calling `request_stop` method.
 
     Examples:
         >>> import numpy as np

@@ -491,9 +488,7 @@
 
     Callback objects not only can obtain the Model context information by calling by
     `RunContext.original_args()` and add extra attributes to the information, but also can stop the
-    training process by calling `request_stop` method.
-    please check
-    `Callback Mechanism <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/callback.html>`_.
+    training process by calling `request_stop` method.
 
     `RunContext.original_args()` holds the model context information as a dictionary variable, and
     different attributes of the dictionary are stored in training or eval process. Details are as follows:

@@ -572,10 +567,6 @@
 
         Returns:
             Dict, an object that holds the original arguments of model.
-
-        Tutorial Examples:
-            - `Callback Mechanism - Customized Callback Mechanism
-              <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#customized-callback-mechanism>`_
         """
         return self._original_args
 

@@ -585,11 +576,6 @@
 
         Callbacks can use this function to request stop of iterations.
         model.train() checks whether this is called or not.
-
-        Tutorial Examples:
-            - `Callback Mechanism - Customized Training Termination Time
-              <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#
-              customized-training-termination-time>`_
         """
         self._stop_requested = True
 
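
The trimmed docstrings still describe the same contract: a custom callback reads run state via `RunContext.original_args()` and can end training with `request_stop()`. A hypothetical sketch (the threshold and loss handling are illustrative, not from the diff):

# Hypothetical early-stop callback: original_args() exposes the run's
# attributes (net_outputs holds the step's loss), and request_stop()
# sets the flag that model.train() checks each iteration.
import mindspore as ms

class StopBelowLoss(ms.train.Callback):
    def __init__(self, threshold=0.05):
        super().__init__()
        self.threshold = threshold

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs
        if loss is not None and float(loss.asnumpy()) < self.threshold:
            run_context.request_stop()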
@@ -18,25 +18,22 @@ from __future__ import absolute_import
|
|
|
18
18
|
import os
|
|
19
19
|
import stat
|
|
20
20
|
import time
|
|
21
|
-
import threading
|
|
22
21
|
|
|
23
22
|
import mindspore.context as context
|
|
24
23
|
from mindspore import log as logger
|
|
25
24
|
from mindspore import nn
|
|
26
25
|
from mindspore import _checkparam as Validator
|
|
27
26
|
from mindspore.train._utils import _make_directory
|
|
28
|
-
from mindspore.train.serialization import save_checkpoint, _save_graph
|
|
27
|
+
from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
|
|
28
|
+
_wait_async_thread_save_ckpt, _check_async_save
|
|
29
29
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
|
30
30
|
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
|
|
31
|
-
from mindspore.
|
|
32
|
-
from mindspore.
|
|
33
|
-
from mindspore.
|
|
34
|
-
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
|
|
35
|
-
from mindspore.train.callback._callback import Callback, set_cur_net
|
|
31
|
+
from mindspore.communication.management import get_rank, get_group_size
|
|
32
|
+
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy, _get_pp_size_from_redundancy_map
|
|
33
|
+
from mindspore.train.callback._callback import Callback
|
|
36
34
|
from mindspore.common.tensor import Tensor
|
|
37
35
|
from mindspore.common.parameter import Parameter
|
|
38
36
|
from mindspore.common.generator import Generator
|
|
39
|
-
from mindspore.common.api import _cell_graph_executor
|
|
40
37
|
from mindspore._c_expression import collect_host_info, get_clock_syscnt
|
|
41
38
|
|
|
42
39
|
_cur_dir = os.getcwd()
|
|
@@ -44,15 +41,6 @@ SAVE_DIR = _cur_dir
|
|
|
44
41
|
_info_list = ["epoch_num", "step_num"]
|
|
45
42
|
|
|
46
43
|
|
|
47
|
-
def _wait_async_save_ckpt(async_save=False):
|
|
48
|
-
"""Waiting for asynchronous saving of ckpt to complete."""
|
|
49
|
-
if async_save:
|
|
50
|
-
thread_list = threading.enumerate()
|
|
51
|
-
for thread in thread_list:
|
|
52
|
-
if thread.getName() == "asyn_save_ckpt":
|
|
53
|
-
thread.join()
|
|
54
|
-
|
|
55
|
-
|
|
56
44
|
def _get_dp_tp_from_redundancy(redundancy_tuple):
|
|
57
45
|
"""From redundancy get dp and tp"""
|
|
58
46
|
dp = []
|
|
@@ -76,6 +64,15 @@ def _get_dp_tp_from_layout(parameter_redundancy_dict):
|
|
|
76
64
|
return dp, tp
|
|
77
65
|
|
|
78
66
|
|
|
67
|
+
def _wait_async_save_ckpt(async_save=False):
|
|
68
|
+
"""Waiting for asynchronous saving of ckpt to complete."""
|
|
69
|
+
if async_save:
|
|
70
|
+
if async_save == "process":
|
|
71
|
+
_wait_async_process_save_ckpt()
|
|
72
|
+
else:
|
|
73
|
+
_wait_async_thread_save_ckpt()
|
|
74
|
+
|
|
75
|
+
|
|
79
76
|
def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
|
|
80
77
|
"""Check if there is a file with the same name."""
|
|
81
78
|
if callable(prefix) or callable(directory):
|
|
@@ -87,7 +84,7 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
|
|
|
87
84
|
name_ext = os.path.splitext(filename)
|
|
88
85
|
if exception and filename[-16:] != "_breakpoint.ckpt":
|
|
89
86
|
continue
|
|
90
|
-
if not exception and (name_ext[-1]
|
|
87
|
+
if not exception and (name_ext[-1] not in (".ckpt", ".safetensors") or filename[-16:] == "_breakpoint.ckpt"):
|
|
91
88
|
continue
|
|
92
89
|
# find same prefix file
|
|
93
90
|
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
|
|
@@ -106,10 +103,10 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
|
|
|
106
103
|
return prefix
|
|
107
104
|
|
|
108
105
|
|
|
109
|
-
def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False,
|
|
106
|
+
def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, exception_save=False,
|
|
110
107
|
map_param_inc=False, global_step_num=None):
|
|
111
|
-
param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or
|
|
112
|
-
or
|
|
108
|
+
param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or exception_save or map_param_inc
|
|
109
|
+
or global_step_num is not None)
|
|
113
110
|
if format == "safetensors" and param_not_default:
|
|
114
111
|
raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
|
|
115
112
|
|
|
@@ -139,7 +136,10 @@ class CheckpointConfig:
|
|
|
139
136
|
integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
|
|
140
137
|
Integrated save function is only supported in automatic parallel scene, not supported
|
|
141
138
|
in manual parallel. Default: ``True`` .
|
|
142
|
-
async_save (bool):
|
|
139
|
+
async_save (Union[bool, str], optional):Whether to use asynchronous saving of the checkpoint file or
|
|
140
|
+
safetensors file, if True, the asynchronous thread is used by default. If the type
|
|
141
|
+
is string, the method of asynchronous saving, it can be "process" or "thread".
|
|
142
|
+
Default: ``False`` .
|
|
143
143
|
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
|
|
144
144
|
with the network in training, the initial value of saved_network will be saved. Default: ``None`` .
|
|
145
145
|
append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and
|
|
@@ -247,7 +247,7 @@ class CheckpointConfig:
|
|
|
247
247
|
self._keep_checkpoint_max = 1
|
|
248
248
|
|
|
249
249
|
self._integrated_save = Validator.check_bool(integrated_save)
|
|
250
|
-
self._async_save =
|
|
250
|
+
self._async_save = _check_async_save(async_save)
|
|
251
251
|
self._saved_network = saved_network
|
|
252
252
|
self._append_dict = self._handle_append_info(append_info)
|
|
253
253
|
self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
|
@@ -258,8 +258,7 @@ class CheckpointConfig:
|
|
|
258
258
|
self.enable_redundance = kwargs.get('enable_redundance', False)
|
|
259
259
|
self.remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
260
260
|
|
|
261
|
-
_check_format_and_other_params(format, enc_key, enc_mode, crc_check,
|
|
262
|
-
self._map_param_inc)
|
|
261
|
+
_check_format_and_other_params(format, enc_key, enc_mode, crc_check, exception_save, self._map_param_inc)
|
|
263
262
|
|
|
264
263
|
@property
|
|
265
264
|
def save_checkpoint_steps(self):
|
|
@@ -313,10 +312,10 @@ class CheckpointConfig:
|
|
|
313
312
|
@property
|
|
314
313
|
def async_save(self):
|
|
315
314
|
"""
|
|
316
|
-
Get the value of whether asynchronous execution saves the checkpoint to a file.
|
|
315
|
+
Get the value of whether or how asynchronous execution saves the checkpoint to a file.
|
|
317
316
|
|
|
318
317
|
Returns:
|
|
319
|
-
bool, whether asynchronous execution saves the checkpoint to a file.
|
|
318
|
+
(bool, str), whether or how asynchronous execution saves the checkpoint to a file.
|
|
320
319
|
"""
|
|
321
320
|
return self._async_save
|
|
322
321
|
|
|
@@ -449,8 +448,9 @@ class ModelCheckpoint(Callback):
|
|
|
449
448
|
Note:
|
|
450
449
|
In the distributed training scenario, please specify different directories for each training process
|
|
451
450
|
to save the checkpoint file. Otherwise, the training may fail.
|
|
452
|
-
If this callback is used in the
|
|
453
|
-
|
|
451
|
+
If this callback is used in the
|
|
452
|
+
`Model <https://www.mindspore.cn/docs/en/master/api_python/train/mindspore.train.Model.html>`_ function,
|
|
453
|
+
the checkpoint file will saved parameters of the optimizer by default.
|
|
454
454
|
|
|
455
455
|
Args:
|
|
456
456
|
prefix (Union[str, callable object]): The prefix name or callable object to generate name of checkpoint files.
|
|
@@ -511,7 +511,7 @@ class ModelCheckpoint(Callback):
|
|
|
511
511
|
if callable(prefix):
|
|
512
512
|
self._prefix_func = prefix
|
|
513
513
|
|
|
514
|
-
if _get_recovery_context("enable_recovery"):
|
|
514
|
+
if context.get_context("device_target") == "GPU" and _get_recovery_context("enable_recovery"):
|
|
515
515
|
_set_recovery_context(ckpt_path=self._directory)
|
|
516
516
|
|
|
517
517
|
if config is None:
|
|
@@ -538,6 +538,8 @@ class ModelCheckpoint(Callback):
|
|
|
538
538
|
self._graph_saved = False
|
|
539
539
|
self._need_flush_from_cache = True
|
|
540
540
|
self._map_param_inc = self._config.map_param_inc
|
|
541
|
+
self._d2h_async = os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
|
|
542
|
+
self._run_mode = context.get_context("mode")
|
|
541
543
|
|
|
542
544
|
def step_end(self, run_context):
|
|
543
545
|
"""
|
|
@@ -551,19 +553,17 @@ class ModelCheckpoint(Callback):
             from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
             ckpt_storage_path = self._directory
             rank_id = get_rank()
-
-            stage_rank_num = _get_device_num() // stage_num
+            device_num = get_group_size()
             param_layout = cb_params.train_network.parameter_layout_dict
             if not param_layout:
-                layout = {"stage_num":
+                layout = {"stage_num": 1, "stage_rank_num": device_num, "stage_layout": None}
                 aiturbo.init(ckpt_storage_path, rank_id, layout, None, False, None)
             else:
-
-
-
-                param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+                param_redundancy_dict = get_parameter_redundancy(param_layout)
+                pp_size = _get_pp_size_from_redundancy_map(param_redundancy_dict)
+                stage_num = device_num // pp_size
                 dp, _ = _get_dp_tp_from_layout(param_redundancy_dict)
-                layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num,
+                layout = {"stage_num": stage_num, "stage_rank_num": pp_size,
                           "stage_layout": param_redundancy_dict}
                 single_params = remove_param_redundancy(param_redundancy_dict)
                 single_params = {device_id: list(params) for device_id, params in single_params.items()}
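The rewritten branch derives the stage count from the redundancy map instead of reading pipeline context: `pp_size` is the number of ranks per pipeline stage, so `stage_num = device_num // pp_size`. A worked sketch with assumed numbers (the redundancy map here is a made-up placeholder, not a real layout):

    device_num = 8                      # assumed world size
    pp_size = 4                         # ranks per stage, as _get_pp_size_from_redundancy_map would return
    stage_num = device_num // pp_size   # -> 2 pipeline stages
    layout = {"stage_num": stage_num, "stage_rank_num": pp_size,
              "stage_layout": {"w1": ((0, 4), (1, 5))}}  # placeholder map
    print(layout["stage_num"], layout["stage_rank_num"])  # 2 4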
@@ -632,6 +632,13 @@ class ModelCheckpoint(Callback):
         if "step_num" in self._append_dict:
             self._append_dict["step_num"] = self._append_step_num + step_num
 
+    def _update_save_step(self, cb_params):
+        """update step if used async d2h copy"""
+        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
+        if self._d2h_async and self._run_mode == context.GRAPH_MODE:
+            step_num_in_epoch -= 1
+        return step_num_in_epoch
+
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
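The new helper keeps the usual in-epoch step arithmetic and subtracts one when the async D2H copy runs in graph mode, presumably because the copied tensors then lag the host step by one. Worked numbers:

    # (2001 - 1) % 1000 + 1 = 1; with the async D2H copy in graph mode -> 0.
    batch_num, cur_step_num = 1000, 2001
    step_num_in_epoch = int((cur_step_num - 1) % batch_num + 1)  # 1
    d2h_async, graph_mode = True, True                           # assumed flags
    if d2h_async and graph_mode:
        step_num_in_epoch -= 1                                   # 0
    print(step_num_in_epoch)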
@@ -642,10 +649,12 @@ class ModelCheckpoint(Callback):
         self._flush_from_cache(cb_params)
 
         save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
+        step_num_in_epoch = self._update_save_step(cb_params)
 
         if save_ckpt:
+
             _wait_async_save_ckpt(self._config.async_save)
+
             if self._prefix_func:
                 cur_ckpoint_file = self._prefix + f".{self._config.format}"
             else:
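Waiting on the previous asynchronous save before starting the next one prevents two writers from racing on checkpoint files. A generic illustration of the pattern (plain threading, not MindSpore internals):

    import threading

    prev_save = threading.Thread(target=lambda: None)  # stand-in for an in-flight async save
    prev_save.start()
    prev_save.join()  # the role _wait_async_save_ckpt plays above
    # ...only now start writing the next checkpoint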
@@ -670,12 +679,6 @@ class ModelCheckpoint(Callback):
             self._last_time_for_keep = time.time()
             self._last_triggered_step = cb_params.cur_step_num
 
-            # TODO(MS_DISABLE_REF_MODE): Delete when remove MS_DISABLE_REF_MODE env.
-            if context.get_context("enable_ge") and os.getenv('MS_DISABLE_REF_MODE') \
-                    and context.get_context("mode") == context.GRAPH_MODE:
-                set_cur_net(cb_params.train_network)
-                cb_params.train_network.add_flags(ge_sync_data=True)
-                _cell_graph_executor(cb_params.train_network, phase='save')
             self._append_dict_content(cb_params.cur_epoch_num, cb_params.cur_step_num)
             network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
             if os.getenv("AITURBO") == "1":
@@ -684,18 +687,13 @@ class ModelCheckpoint(Callback):
                                 crc_check=self._config.crc_check, incremental=self._map_param_inc,
                                 global_step_num=cb_params.cur_step_num)
             elif self._config.remove_redundancy:
-
-                if parallel_mode == "stand_alone":
+                if get_group_size() == 1:
                     raise TypeError(f"The deduplication feature for saving checkpoint can only be used "
-                                    f"in parallel scenarios, but got {parallel_mode}.")
+                                    f"in parallel scenarios, but got 'stand_alone'.")
                 param_layout = network.parameter_layout_dict
                 rank_id = get_rank()
                 if param_layout:
-
-                    stage_num = _get_auto_parallel_context("pipeline_stages")
-                    chunk_size = device_num // stage_num
-                    initial_rank = (rank_id // chunk_size) * chunk_size
-                    param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+                    param_redundancy_dict = get_parameter_redundancy(param_layout)
                     single_params = remove_param_redundancy(param_redundancy_dict)
                     save_param_names = single_params.get(rank_id)
                     param_layout_set = set(param_layout.keys())
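The stand-alone check now keys off the communication group size rather than a parallel-mode string. The guard in isolation, as a sketch:

    from mindspore.communication import get_group_size

    if get_group_size() == 1:  # single rank: nothing to deduplicate
        raise TypeError("The deduplication feature for saving checkpoint can only "
                        "be used in parallel scenarios, but got 'stand_alone'.")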
@@ -704,14 +702,14 @@ class ModelCheckpoint(Callback):
                                 f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.")
 
                 def choice_func(x):
-                    return x not in param_layout_set or x in save_param_names
+                    return x not in param_layout_set or (save_param_names is not None and x in save_param_names)
             else:
                 param_redundancy_dict = get_parameter_redundancy(network)
                 single_params = remove_param_redundancy(param_redundancy_dict)
                 save_param_names = single_params.get(rank_id)
 
                 def choice_func(x):
-                    return x in save_param_names
+                    return save_param_names is not None and x in save_param_names
             save_checkpoint(network, cur_file, False, self._config.async_save,
                             self._append_dict, self._config.enc_key, self._config.enc_mode,
                             crc_check=self._config.crc_check, format=self._config.format,
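The added `is not None` guards matter because `single_params.get(rank_id)` returns None for a rank that owns no parameters, and `x in None` raises. Minimal repro of the failure mode the change avoids:

    single_params = {0: {"w1", "w2"}}        # hypothetical rank -> owned params
    save_param_names = single_params.get(1)  # rank 1 owns nothing -> None

    def choice_func(x):
        return save_param_names is not None and x in save_param_names

    print(choice_func("w1"))  # False, where `"w1" in None` would raise TypeError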
@@ -24,9 +24,8 @@ from threading import RLock
 from mindspore.train.callback._callback import Callback
 from mindspore.communication.management import get_rank, get_local_rank
 from mindspore import log as logger
-from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
 from mindspore.parallel._utils import _get_device_num
-from mindspore.train._utils import get_parameter_redundancy
+from mindspore.train._utils import get_parameter_redundancy, _get_pp_size_from_redundancy_map
 
 _perf_mutex = RLock()
 
@@ -42,7 +41,7 @@ def _get_dp_tp_from_redundancy(redundancy_tuple):
     return dp, tp
 
 
-def _get_dp_tp_from_layout(parameter_layout_dict, initial_rank=0):
+def _get_dp_tp_from_layout(parameter_layout_dict, initial_rank=None):
     """From layout dict get dp and tp"""
     tp = []
     dp = []
@@ -132,21 +131,9 @@ class ClusterMonitor(Callback):
         self.full_path = self.log_path + self.log_name
 
         self.write_dp_tp_flag = True
-        self.initial_rank = 0
 
     def begin(self, run_context):
         _remove_pre_log()
-        pp_num = _get_auto_parallel_context("pipeline_stages")
-        device_num = _get_device_num()
-
-        original_list = list(range(device_num))
-        chunk_size = device_num // pp_num
-        split_pp_lists = []
-        for i in range(0, device_num, chunk_size):
-            end_index = i + chunk_size if i + chunk_size <= device_num else device_num
-            split_pp_lists.append(original_list[i:end_index])
-
-        self.initial_rank = (self.global_rank // chunk_size) * chunk_size
         with _perf_mutex:
             dir_path = os.path.dirname(self.full_path)
             if not os.path.exists(dir_path):
@@ -157,8 +144,6 @@ class ClusterMonitor(Callback):
             with open(self.full_path, 'w') as file:
                 log_message = f'UUID:{self.uuid_value}\nFRAMEWORK:{self.frame_work}\nGLOBAL RANKID:{self.global_rank}\n'
                 file.write(log_message)
-                for _, split_pp_list in enumerate(split_pp_lists):
-                    file.write(f'PP:{split_pp_list}\n')
             os.chmod(self.full_path, stat.S_IRUSR)
 
     def step_begin(self, run_context):
@@ -183,10 +168,21 @@ class ClusterMonitor(Callback):
         if self.enabled and self.enabled_dtp_group and self.write_dp_tp_flag:
             cb_params = run_context.original_args()
             param_layout_dict = cb_params.train_network.parameter_layout_dict
-            dp, tp = _get_dp_tp_from_layout(param_layout_dict, self.initial_rank)
+            device_num = _get_device_num()
+            original_list = list(range(device_num))
+            param_redundancy_dict = get_parameter_redundancy(param_layout_dict)
+            pp_size = _get_pp_size_from_redundancy_map(param_redundancy_dict)
+            split_pp_lists = []
+            for i in range(0, device_num, pp_size):
+                end_index = i + pp_size if i + pp_size <= device_num else device_num
+                split_pp_lists.append(original_list[i:end_index])
+            dp, tp = _get_dp_tp_from_layout(param_layout_dict)
+
             with _perf_mutex:
                 os.chmod(self.full_path, stat.S_IWUSR)
                 with open(self.full_path, 'a') as file:
+                    for _, split_pp_list in enumerate(split_pp_lists):
+                        file.write(f'PP:{split_pp_list}\n')
                     for dp_value in dp:
                         file.write(f'dp:{dp_value}\n')
                     for tp_value in tp:
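The PP grouping that used to happen in `begin` is now computed in `step_begin` from the redundancy map. A worked instance of the chunking with assumed sizes:

    # 8 devices, 4 ranks per stage -> PP:[0, 1, 2, 3] and PP:[4, 5, 6, 7].
    device_num, pp_size = 8, 4
    original_list = list(range(device_num))
    split_pp_lists = []
    for i in range(0, device_num, pp_size):
        end_index = i + pp_size if i + pp_size <= device_num else device_num
        split_pp_lists.append(original_list[i:end_index])
    print(split_pp_lists)  # [[0, 1, 2, 3], [4, 5, 6, 7]]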