mindspore 2.5.0__cp310-cp310-win_amd64.whl → 2.6.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -107,22 +107,18 @@ class SummaryCollector(Callback):
|
|
|
107
107
|
The first output will be treated as the loss and it will be averaged. Default: ``True`` .
|
|
108
108
|
- collect_graph (bool): Whether to collect the computational graph. Currently, only
|
|
109
109
|
training computational graph is collected. Default: ``True`` .
|
|
110
|
-
- collect_train_lineage (bool): Whether to collect lineage data for the training phase
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
- collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase,
|
|
115
|
-
this field will be displayed on the `lineage page
|
|
116
|
-
<https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_
|
|
117
|
-
of MindInsight. Default: ``True`` .
|
|
110
|
+
- collect_train_lineage (bool): Whether to collect lineage data for the training phase.
|
|
111
|
+
Default: ``True`` .
|
|
112
|
+
- collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase.
|
|
113
|
+
Default: ``True`` .
|
|
118
114
|
- collect_input_data (bool): Whether to collect dataset for each training.
|
|
119
115
|
Currently only image data is supported.
|
|
120
116
|
If there are multiple columns of data in the dataset, the first column should be image data.
|
|
121
117
|
Default: ``True`` .
|
|
122
118
|
- collect_dataset_graph (bool): Whether to collect dataset graph for the training phase.
|
|
123
119
|
Default: ``True`` .
|
|
124
|
-
- histogram_regular (Union[str, None]): Collect weight and bias for parameter distribution page
|
|
125
|
-
|
|
120
|
+
- histogram_regular (Union[str, None]): Collect weight and bias for parameter distribution page.
|
|
121
|
+
This field allows regular strings to control which parameters to collect.
|
|
126
122
|
It is not recommended to collect too many parameters at once, as it can affect performance.
|
|
127
123
|
Note that if you collect too many parameters and run out of memory, the training will fail.
|
|
128
124
|
Default: ``None`` , it means only the first five parameters are collected.
|
|
@@ -153,8 +149,7 @@ class SummaryCollector(Callback):
|
|
|
153
149
|
True: it means that after specified data is set, non-specified data is collected as the default behavior.
|
|
154
150
|
False: it means that after specified data is set, only the specified data is collected,
|
|
155
151
|
and the others are not collected. Default: ``True`` .
|
|
156
|
-
custom_lineage_data (Union[dict, None]): Allows you to customize the data
|
|
157
|
-
`lineage page <https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_ .
|
|
152
|
+
custom_lineage_data (Union[dict, None]): Allows you to customize the data.
|
|
158
153
|
In the custom data, the type of the key supports str, and the type of value supports str, int
|
|
159
154
|
and float. Default: ``None`` , it means there is no custom data.
|
|
160
155
|
collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
|
|
@@ -168,7 +163,7 @@ class SummaryCollector(Callback):
|
|
|
168
163
|
affect the number of steps TensorSummary will be collected.
|
|
169
164
|
Default: ``None`` , which means to follow the behavior as described above.
|
|
170
165
|
max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
|
|
171
|
-
For example, to write not larger than 4GB, specify `max_file_size=4*1024
|
|
166
|
+
For example, to write not larger than 4GB, specify `max_file_size=4*1024*3`.
|
|
172
167
|
Default: ``None`` , which means no limit.
|
|
173
168
|
export_options (Union[None, dict]): Perform custom operations on the export data.
|
|
174
169
|
Note that the size of export files is not limited by the max_file_size.
|
|
@@ -28,7 +28,8 @@ class TimeMonitor(Callback):
|
|
|
28
28
|
Args:
|
|
29
29
|
data_size (int): How many steps are the intervals between print information each time.
|
|
30
30
|
if the program get `batch_num` during training, `data_size` will be set to `batch_num`,
|
|
31
|
-
otherwise `data_size` will be used.
|
|
31
|
+
otherwise `data_size` will be used. If the program does not get `batch_num` during training,
|
|
32
|
+
meanwhile `data_size` does not set, the program will report an error. Default: ``None`` .
|
|
32
33
|
|
|
33
34
|
data_time (bool): Whether to show the average time of fetching data in Host.
|
|
34
35
|
Note that data fetch and network compute are processed sequentially in non dataset sink mode, while
|
|
@@ -15,24 +15,27 @@
|
|
|
15
15
|
"""Checkpoint related classes and functions."""
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
|
+
from mindspore.utils import _tft_handler
|
|
18
19
|
from mindspore.train.serialization import save_checkpoint
|
|
19
|
-
from mindspore.parallel._utils import _get_device_num
|
|
20
|
-
from mindspore import _checkparam as Validator
|
|
21
20
|
from mindspore.train.callback._callback import Callback
|
|
22
|
-
from mindspore import context
|
|
21
|
+
from mindspore import context, ops
|
|
23
22
|
from mindspore.common.parameter import Parameter
|
|
24
23
|
from mindspore.common.tensor import Tensor
|
|
25
24
|
from mindspore.communication import get_rank, get_group_size
|
|
26
25
|
from mindspore import log as logger
|
|
27
26
|
from mindspore.train.serialization import _get_cur_rank_dp
|
|
28
27
|
from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
|
|
28
|
+
from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
|
|
29
29
|
from mindspore._c_expression import clean_tdt_channel
|
|
30
30
|
from mindspore._c_expression import send_recv, reset_params
|
|
31
31
|
from mindspore._c_expression import CollectiveManager
|
|
32
32
|
from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
|
|
33
|
-
from mindspore._c_expression import
|
|
33
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
34
|
+
from mindspore.ops.operations.manually_defined._inner import TensorReport
|
|
34
35
|
import mindspore
|
|
35
36
|
import mindspore.common.dtype as mstype
|
|
37
|
+
from mindspore.parallel._recovery_context import _set_recovery_context
|
|
38
|
+
|
|
36
39
|
|
|
37
40
|
def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
|
|
38
41
|
""" Common func to generate ckpt dir name."""
|
|
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
|
|
|
40
43
|
mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
|
|
41
44
|
return os.path.join(ckpt_save_path, mid_dir)
|
|
42
45
|
|
|
46
|
+
|
|
43
47
|
def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
|
|
44
48
|
""" Callback used for TFT save ckpt function when errors occur."""
|
|
45
49
|
logger.info("Enter _save_checkpoint_on_failure function")
|
|
46
|
-
if not cb_ctx._is_params_consistent():
|
|
50
|
+
if not cb_ctx._is_params_consistent(): # pylint: disable=W0212
|
|
47
51
|
raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
|
|
52
|
+
cb_params = args
|
|
53
|
+
# we record the current step and epoch num in on_train_step_end, so we can just reset it here
|
|
54
|
+
cb_params.cur_step_num = cb_ctx.cur_step_num
|
|
55
|
+
cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
|
|
56
|
+
if cb_params.optimizer is not None:
|
|
57
|
+
cb_params.optimizer.global_step = cb_ctx.global_step
|
|
58
|
+
if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
|
|
59
|
+
cb_params.network.optimizer.global_step = cb_ctx.global_step
|
|
60
|
+
append_dict = {}
|
|
61
|
+
append_dict["__exception_save__"] = True
|
|
62
|
+
# if user has provided a custom save callback, use it
|
|
63
|
+
if cb_ctx.save_cb:
|
|
64
|
+
cb_ctx.save_cb(cb_params, append_dict)
|
|
65
|
+
logger.info("Finish _save_checkpoint_on_failure function")
|
|
66
|
+
return
|
|
48
67
|
|
|
68
|
+
# if user has not provided a custom save callback, use default save logic
|
|
49
69
|
ckpt_save_path = cb_ctx.ckpt_save_path
|
|
50
|
-
cb_params = args
|
|
51
70
|
cur_rank = get_rank()
|
|
52
|
-
|
|
71
|
+
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
|
|
53
72
|
cur_epoch_num = cb_params.cur_epoch_num
|
|
54
|
-
batch_num = cb_params.batch_num
|
|
55
|
-
if cur_step_num > step:
|
|
56
|
-
cur_epoch_num = (step - 1) // batch_num + 1
|
|
57
|
-
step_num_in_epoch = int((step - 1) % batch_num + 1)
|
|
58
|
-
|
|
59
|
-
append_dict = {}
|
|
60
73
|
append_dict["epoch_num"] = cur_epoch_num
|
|
61
|
-
append_dict["step_num"] =
|
|
74
|
+
append_dict["step_num"] = cb_params.cur_step_num
|
|
62
75
|
append_dict["cur_rank"] = cur_rank
|
|
63
|
-
append_dict["batch_num"] = batch_num
|
|
64
|
-
append_dict["
|
|
65
|
-
|
|
66
|
-
append_dict["global_step"] = Parameter([cb_ctx.global_step])
|
|
76
|
+
append_dict["batch_num"] = cb_params.batch_num
|
|
77
|
+
append_dict["global_step"] = cb_ctx.global_step
|
|
67
78
|
outputs = cb_params.net_outputs
|
|
68
79
|
if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
|
|
69
80
|
append_dict["loss_scale"] = outputs[2]
|
|
@@ -76,49 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
|
|
|
76
87
|
integrated_save=False, append_dict=append_dict)
|
|
77
88
|
logger.info("Finish _save_checkpoint_on_failure function")
|
|
78
89
|
|
|
90
|
+
|
|
79
91
|
def _rename_save_result(step, cb_ctx):
|
|
80
92
|
""" Callback used for TFT rename function after ckpt save callback was finished and successful."""
|
|
81
93
|
logger.info("Enter _rename_save_result function")
|
|
94
|
+
if cb_ctx.save_cb:
|
|
95
|
+
logger.info("User's save callback is provided, skip rename")
|
|
96
|
+
return
|
|
82
97
|
tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
|
|
83
98
|
fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
|
|
84
99
|
|
|
85
100
|
os.rename(tmp_dir, fin_dir)
|
|
86
101
|
logger.info("Finish _rename_save_result function")
|
|
87
102
|
|
|
103
|
+
|
|
88
104
|
def _tft_exit_cb(ctx):
|
|
105
|
+
"""Callback used for TFT exit function."""
|
|
89
106
|
logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
|
|
90
107
|
_tft_sem_post()
|
|
91
|
-
os._exit(1)
|
|
108
|
+
os._exit(1) # pylint: disable=W0212
|
|
92
109
|
|
|
93
110
|
|
|
94
111
|
def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
|
|
95
112
|
""" Callback used for TFT repair function."""
|
|
96
|
-
logger.
|
|
97
|
-
if(repair_info["repair_type"]
|
|
98
|
-
|
|
99
|
-
logger.
|
|
113
|
+
logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
|
|
114
|
+
if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
|
|
115
|
+
cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
|
|
116
|
+
logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
|
|
100
117
|
_repair_device(cb_ctx.device_id)
|
|
101
118
|
|
|
102
|
-
if(repair_info["repair_type"]
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
{}, src_rank:{}, dst_rank: {}".format(
|
|
119
|
+
if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
|
|
120
|
+
cb_ctx.tft.RepairType.RT_SEND.value,
|
|
121
|
+
cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
|
|
122
|
+
logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
|
|
123
|
+
repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
|
|
106
124
|
cb_params = args
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
125
|
+
if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
|
|
126
|
+
for i in range(len(repair_info["src"])):
|
|
127
|
+
src_rank = repair_info["src"][i]
|
|
128
|
+
dst_rank = repair_info["dst"][i]
|
|
129
|
+
if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
|
|
130
|
+
raise ValueError("Call send_recv failed.")
|
|
131
|
+
else:
|
|
132
|
+
src_rank = repair_info["src"][0]
|
|
133
|
+
dst_rank = repair_info["dst"][0]
|
|
134
|
+
if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
|
|
135
|
+
raise ValueError("Call send_recv failed.")
|
|
136
|
+
logger.warning("Finish _tft_repair_callback")
|
|
112
137
|
|
|
113
138
|
|
|
114
139
|
def _tft_clean_callback(is_uce_error, args, ctx):
|
|
115
140
|
""" Callback used for TFT clean function."""
|
|
116
|
-
logger.
|
|
141
|
+
logger.warning("Enter _tft_clean_callback")
|
|
117
142
|
ret = 0
|
|
118
143
|
if is_uce_error:
|
|
119
144
|
_get_uce_mem_info(ctx.device_id)
|
|
120
145
|
err_strategy = _get_uce_process_strategy()
|
|
121
|
-
logger.
|
|
146
|
+
logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
|
|
122
147
|
if err_strategy == "RS_UCE_HIGHLEVEL":
|
|
123
148
|
ret = 0
|
|
124
149
|
elif err_strategy == "RS_UCE_LOWLEVEL":
|
|
@@ -126,37 +151,49 @@ def _tft_clean_callback(is_uce_error, args, ctx):
|
|
|
126
151
|
else:
|
|
127
152
|
ret = 1
|
|
128
153
|
clean_tdt_channel()
|
|
129
|
-
logger.
|
|
154
|
+
logger.warning("Enter _tft_clean_callback resume_hccl_comm")
|
|
130
155
|
CollectiveManager.get_instance().resume_hccl_comm()
|
|
131
|
-
logger.
|
|
156
|
+
logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
|
|
132
157
|
return ret
|
|
133
158
|
|
|
134
159
|
|
|
135
160
|
def _tft_stop_callback(args, cb_ctx):
|
|
136
161
|
""" Callback used for TFT stop function."""
|
|
137
|
-
logger.
|
|
162
|
+
logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
|
|
138
163
|
_stop_device(cb_ctx.device_id)
|
|
139
|
-
if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):
|
|
164
|
+
if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
|
|
140
165
|
raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
|
|
141
166
|
cb_ctx.is_uce_rank = False
|
|
167
|
+
if cb_ctx.tft.tft_get_repair_type() == "recover":
|
|
168
|
+
logger.warning(f"Reset limit step")
|
|
169
|
+
cb_ctx.tft.tft_reset_limit_step()
|
|
142
170
|
logger.info("Finish _tft_stop_callback")
|
|
143
171
|
|
|
144
172
|
|
|
145
|
-
|
|
173
|
+
def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
|
|
174
|
+
"""Callback used for TFT Rebuild Group function."""
|
|
175
|
+
logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
|
|
176
|
+
_finalize_comm()
|
|
177
|
+
_rebuild_world_group()
|
|
178
|
+
_rebuild_sub_group()
|
|
179
|
+
_set_recovery_context(is_arf=True)
|
|
180
|
+
logger.warning("Enter _tft_rebuild_sub_groups ok ")
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class TrainFaultTolerance(Callback):
|
|
146
184
|
"""
|
|
147
185
|
This callback is used to enable the TFT feature
|
|
148
|
-
`MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
|
|
149
|
-
|
|
186
|
+
`MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
|
|
187
|
+
and will execute TFT operations during training process, such as TFT init, report and exception handle.
|
|
150
188
|
|
|
151
189
|
Note:
|
|
152
190
|
Required for Ascend graph mode only. And sink size must be less than or equal to 1.
|
|
153
191
|
|
|
154
192
|
Args:
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
|
|
193
|
+
ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
|
|
194
|
+
a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
|
|
195
|
+
is created in that directory. Default: ``None``.
|
|
196
|
+
kwargs (dict): Other dictionary type parameters.
|
|
160
197
|
|
|
161
198
|
Raises:
|
|
162
199
|
Exception: TFT init failed.
|
|
@@ -168,7 +205,7 @@ class TFTRegister(Callback):
|
|
|
168
205
|
|
|
169
206
|
It's recommended to use the msrun startup method.
|
|
170
207
|
Please see the `msrun start up
|
|
171
|
-
<https://www.mindspore.cn/
|
|
208
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
172
209
|
for more details.
|
|
173
210
|
|
|
174
211
|
This example should be run with 4 devices.
|
|
@@ -181,7 +218,7 @@ class TFTRegister(Callback):
|
|
|
181
218
|
>>> from mindspore import nn, ops, Parameter, train
|
|
182
219
|
>>> from mindspore.communication import init, get_rank
|
|
183
220
|
>>> from mindspore.common.initializer import initializer, HeUniform
|
|
184
|
-
>>> from mindspore.train import Model,
|
|
221
|
+
>>> from mindspore.train import Model, TrainFaultTolerance
|
|
185
222
|
>>> from mindspore import dataset as ds
|
|
186
223
|
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
|
|
187
224
|
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
|
|
@@ -252,43 +289,68 @@ class TFTRegister(Callback):
|
|
|
252
289
|
>>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
|
|
253
290
|
>>> loss_fn = nn.CrossEntropyLoss()
|
|
254
291
|
>>>
|
|
255
|
-
>>> net_with_loss = nn.
|
|
292
|
+
>>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
|
|
256
293
|
>>> net_with_loss.set_train()
|
|
257
294
|
>>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
|
|
258
|
-
>>> tft_cb =
|
|
295
|
+
>>> tft_cb = TrainFaultTolerance()
|
|
259
296
|
>>> loss_cb = train.LossMonitor(1)
|
|
260
297
|
>>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
|
|
261
298
|
"""
|
|
262
299
|
|
|
263
|
-
def __init__(self,
|
|
264
|
-
super(
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
if
|
|
268
|
-
raise ValueError("
|
|
269
|
-
mode = context.get_context("mode")
|
|
270
|
-
device_target = context.get_context("device_target")
|
|
271
|
-
if device_target != "Ascend" or mode != context.GRAPH_MODE:
|
|
272
|
-
raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
|
|
273
|
-
|
|
274
|
-
# let it raise errors if not install mindio_tft package
|
|
275
|
-
from mindio_ttp import framework_ttp as tft
|
|
276
|
-
self.tft = tft
|
|
277
|
-
self.global_step = 0
|
|
278
|
-
Validator.check_non_negative_int(ctrl_port)
|
|
279
|
-
self.has_init_replica = False
|
|
280
|
-
self.is_uce_rank = False
|
|
281
|
-
self._controller_ip = ctrl_ip
|
|
282
|
-
self._controller_rank_id = ctrl_rank_id
|
|
283
|
-
self._controller_port = ctrl_port
|
|
300
|
+
def __init__(self, ckpt_save_path=None, **kwargs):
|
|
301
|
+
super(TrainFaultTolerance, self).__init__()
|
|
302
|
+
self.save_cb = kwargs.get("ckpt_save_fn", None)
|
|
303
|
+
self.ckpt_save_path = ckpt_save_path
|
|
304
|
+
if self.save_cb is None and self.ckpt_save_path is None:
|
|
305
|
+
raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
|
|
284
306
|
self.cb_params = None
|
|
307
|
+
self.initial_step = kwargs.get("initial_step", 0)
|
|
285
308
|
self.device_id = context.get_context("device_id")
|
|
286
|
-
self.
|
|
287
|
-
self.
|
|
309
|
+
self.cur_step_num = 0
|
|
310
|
+
self.cur_epoch_num = 0
|
|
311
|
+
# For TREError(Training Result Error) scene, parameter `ckpt_load_fn` must be provided to load checkpoint
|
|
312
|
+
# from file for resuming training, the `ckpt_load_fn` is a function, prototype of which is:
|
|
313
|
+
# `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
|
|
314
|
+
# i.e. (param_dict, remove_redundancy)
|
|
315
|
+
self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
|
|
316
|
+
self.tft = _tft_handler.get_tft()
|
|
317
|
+
if self._only_enable_tre():
|
|
318
|
+
return
|
|
319
|
+
self._check_init()
|
|
320
|
+
self.global_step = None
|
|
321
|
+
self.learning_rate = None
|
|
322
|
+
self.has_init_replica = False
|
|
323
|
+
self.is_uce_rank = False
|
|
324
|
+
|
|
288
325
|
self.assign = mindspore.ops.Assign()
|
|
289
326
|
self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
|
|
290
327
|
self.s1 = mindspore.hal.Stream()
|
|
291
328
|
_tft_sem_enable()
|
|
329
|
+
self._tft_register()
|
|
330
|
+
|
|
331
|
+
def _only_enable_tre(self):
|
|
332
|
+
"""Check if only configured MS_ENABLE_TFT='{TRE:1}'"""
|
|
333
|
+
env_enable = os.getenv("MS_ENABLE_TFT", "")
|
|
334
|
+
non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
|
|
335
|
+
if any(flag in env_enable for flag in non_tre_flags):
|
|
336
|
+
return False
|
|
337
|
+
return "TRE:1" in env_enable
|
|
338
|
+
|
|
339
|
+
def _check_init(self):
|
|
340
|
+
"""Check if the mindio-ttp had inited"""
|
|
341
|
+
if self.tft is None:
|
|
342
|
+
tft_env = os.getenv("MS_ENABLE_TFT", "")
|
|
343
|
+
if "ARF:1" in tft_env:
|
|
344
|
+
raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
|
|
345
|
+
logger.warning(f"TFT handle not init, try to init")
|
|
346
|
+
_tft_handler.init(config=None)
|
|
347
|
+
self.tft = _tft_handler.get_tft()
|
|
348
|
+
logger.warning(f"TFT handle init ok.")
|
|
349
|
+
mode = context.get_context("mode")
|
|
350
|
+
device_target = context.get_context("device_target")
|
|
351
|
+
if device_target != "Ascend" or mode != context.GRAPH_MODE:
|
|
352
|
+
raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
|
|
353
|
+
f"device:{device_target}, run mode: {mode}")
|
|
292
354
|
|
|
293
355
|
def _is_params_consistent(self):
|
|
294
356
|
for key, param in self.cb_params.train_network.parameters_and_names():
|
|
@@ -300,7 +362,7 @@ class TFTRegister(Callback):
|
|
|
300
362
|
return False
|
|
301
363
|
|
|
302
364
|
def _set_tft_optimizer_replica(self, run_context):
|
|
303
|
-
"""
|
|
365
|
+
""" Set Mindio TFT optimizer replica info, used internal. """
|
|
304
366
|
cur_rank = get_rank()
|
|
305
367
|
cb_params = run_context.original_args()
|
|
306
368
|
train_network = cb_params.train_network
|
|
@@ -322,33 +384,49 @@ class TFTRegister(Callback):
|
|
|
322
384
|
]
|
|
323
385
|
self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
|
|
324
386
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
387
|
+
@classmethod
|
|
388
|
+
def get_optimizer_wrapper(cls, origin_opt_cls):
|
|
389
|
+
"""
|
|
390
|
+
Optimizer wrapper func when using tft.
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
origin_opt_cls (Class): origin optimizer class.
|
|
394
|
+
"""
|
|
395
|
+
|
|
396
|
+
class TFTOptSubCls(origin_opt_cls):
|
|
397
|
+
"""
|
|
398
|
+
Optimizer wrapper class when using tft.
|
|
399
|
+
"""
|
|
400
|
+
|
|
401
|
+
def __init__(self, *args, **kwargs):
|
|
402
|
+
super(TFTOptSubCls, self).__init__(*args, **kwargs)
|
|
403
|
+
self.report = TensorReport()
|
|
404
|
+
self.report_end = TensorReport()
|
|
405
|
+
self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
|
|
406
|
+
self.depend = ops.Depend()
|
|
407
|
+
self.allreduce_sum = ops.AllReduce()
|
|
408
|
+
self.allreduce_sum.add_prim_attr("tft_report_before", True)
|
|
409
|
+
self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
|
|
410
|
+
|
|
411
|
+
def construct(self, gradients, **kwargs):
|
|
412
|
+
tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
|
|
413
|
+
self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
|
|
414
|
+
grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
|
|
415
|
+
opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
|
|
416
|
+
self.report_end("tft_report", self.tft_g_one_flag)
|
|
417
|
+
return opt_ret
|
|
418
|
+
|
|
419
|
+
return TFTOptSubCls
|
|
420
|
+
|
|
421
|
+
def _tft_register(self):
|
|
422
|
+
"""Register callback functions."""
|
|
328
423
|
self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
|
|
329
424
|
self.tft.tft_register_rename_handler(_rename_save_result, self)
|
|
330
425
|
self.tft.tft_register_exit_handler(_tft_exit_cb, self)
|
|
331
426
|
self.tft.tft_register_stop_handler(_tft_stop_callback, self)
|
|
332
427
|
self.tft.tft_register_clean_handler(_tft_clean_callback, self)
|
|
333
428
|
self.tft.tft_register_repair_handler(_tft_repair_callback, self)
|
|
334
|
-
|
|
335
|
-
world_size = _get_device_num()
|
|
336
|
-
cur_rank = get_rank()
|
|
337
|
-
enable_local_copy = False
|
|
338
|
-
enable_arf = False
|
|
339
|
-
enable_tls = False
|
|
340
|
-
tls_key_dir = ""
|
|
341
|
-
|
|
342
|
-
if cur_rank == self._controller_rank_id:
|
|
343
|
-
logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
|
|
344
|
-
self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
|
|
345
|
-
self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
|
|
346
|
-
logger.info("Finish start tft controller.")
|
|
347
|
-
|
|
348
|
-
logger.info("Begin to start tft processor.")
|
|
349
|
-
self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
|
|
350
|
-
self.tft.tft_start_processor(self._controller_ip, self._controller_port)
|
|
351
|
-
logger.info("Finished start tft processor.")
|
|
429
|
+
self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
|
|
352
430
|
|
|
353
431
|
def _reset_acc_grads(self):
|
|
354
432
|
accu_grad_params = map(lambda e: e[1],
|
|
@@ -360,29 +438,44 @@ class TFTRegister(Callback):
|
|
|
360
438
|
|
|
361
439
|
def on_train_step_end(self, run_context):
|
|
362
440
|
"""
|
|
363
|
-
|
|
441
|
+
Report status to MindIO TFT after every step finished.
|
|
364
442
|
|
|
365
443
|
Args:
|
|
366
444
|
run_context (RunContext): Context of the train running. Refer to
|
|
367
445
|
:class:`mindspore.train.RunContext` for detail.
|
|
368
446
|
"""
|
|
447
|
+
if self._only_enable_tre():
|
|
448
|
+
return
|
|
369
449
|
if self.has_init_replica is False:
|
|
370
450
|
self.has_init_replica = True
|
|
371
451
|
self._set_tft_optimizer_replica(run_context)
|
|
372
452
|
cb_params = run_context.original_args()
|
|
373
453
|
logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
|
|
454
|
+
self.cur_step_num = cb_params.cur_step_num
|
|
455
|
+
self.cur_epoch_num = cb_params.cur_epoch_num
|
|
374
456
|
if cb_params.optimizer is not None:
|
|
375
|
-
self.global_step =
|
|
457
|
+
self.global_step = cb_params.optimizer.global_step.clone()
|
|
376
458
|
self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
|
|
377
|
-
|
|
378
|
-
self.global_step =
|
|
459
|
+
elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
|
|
460
|
+
self.global_step = cb_params.network.optimizer.global_step.clone()
|
|
379
461
|
self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
|
|
380
|
-
|
|
462
|
+
else:
|
|
463
|
+
raise ValueError("TFT feature need optimizer or network's optimizer!")
|
|
464
|
+
self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
|
|
381
465
|
logger.info("END Set optimizer finish step status to TFT.")
|
|
382
466
|
|
|
383
|
-
|
|
384
467
|
def on_train_begin(self, run_context):
|
|
468
|
+
"""
|
|
469
|
+
Register train params to MindIO TFT on train beginning.
|
|
470
|
+
|
|
471
|
+
Args:
|
|
472
|
+
run_context (RunContext): Context of the train running. Refer to
|
|
473
|
+
:class:`mindspore.train.RunContext` for detail.
|
|
474
|
+
"""
|
|
385
475
|
cb_params = run_context.original_args()
|
|
476
|
+
if self._only_enable_tre():
|
|
477
|
+
self.cb_params = cb_params
|
|
478
|
+
return
|
|
386
479
|
sink_size = cb_params.get("sink_size", 0)
|
|
387
480
|
if sink_size > 1:
|
|
388
481
|
raise ValueError("TFT feature doesn't support sink_size > 1.")
|
|
@@ -391,7 +484,13 @@ class TFTRegister(Callback):
|
|
|
391
484
|
self.cb_params = cb_params
|
|
392
485
|
|
|
393
486
|
def end(self, run_context):
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
487
|
+
"""
|
|
488
|
+
Unregister MindIO TFT on train end.
|
|
489
|
+
|
|
490
|
+
Args:
|
|
491
|
+
run_context (RunContext): Context of the train running. Refer to
|
|
492
|
+
:class:`mindspore.train.RunContext` for detail.
|
|
493
|
+
"""
|
|
494
|
+
if self._only_enable_tre():
|
|
495
|
+
return
|
|
496
|
+
_tft_handler.unregister_tft()
|
mindspore/train/data_sink.py
CHANGED
|
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
|
|
|
98
98
|
return next_op, (key, dataset_shapes, dataset_types)
|
|
99
99
|
|
|
100
100
|
|
|
101
|
+
def _get_jit_func(sink_fun, jit_config):
|
|
102
|
+
"""
|
|
103
|
+
Get the jit function.
|
|
104
|
+
"""
|
|
105
|
+
jit_config_dict = jit_config.jit_config_dict
|
|
106
|
+
jit_level = jit_config_dict['jit_level']
|
|
107
|
+
if jit_level == "":
|
|
108
|
+
jit_level = "O0"
|
|
109
|
+
backend = ""
|
|
110
|
+
if jit_level == "O2":
|
|
111
|
+
jit_level = "O0"
|
|
112
|
+
backend = "GE"
|
|
113
|
+
if "backend" in jit_config_dict:
|
|
114
|
+
backend = jit_config_dict["backend"]
|
|
115
|
+
fullgraph = False
|
|
116
|
+
if jit_config_dict['jit_syntax_level'] == "STRICT":
|
|
117
|
+
fullgraph = True
|
|
118
|
+
exc_mode = jit_config_dict['exc_mode']
|
|
119
|
+
infer_boost = jit_config_dict['infer_boost']
|
|
120
|
+
return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
|
|
121
|
+
infer_boost=infer_boost)
|
|
122
|
+
|
|
123
|
+
|
|
101
124
|
def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
|
|
102
125
|
"""
|
|
103
126
|
get the sink function.
|
|
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
|
|
|
107
130
|
if jit_config is None:
|
|
108
131
|
dst_sink_fun = sink_fun
|
|
109
132
|
else:
|
|
110
|
-
dst_sink_fun =
|
|
133
|
+
dst_sink_fun = _get_jit_func(sink_fun, jit_config)
|
|
111
134
|
dataset.__sink_fun__ = dst_sink_fun
|
|
112
135
|
|
|
113
136
|
return dataset.__sink_fun__
|
|
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
|
|
|
119
142
|
if jit_config is None:
|
|
120
143
|
dst_sink_fun = sink_fun
|
|
121
144
|
else:
|
|
122
|
-
dst_sink_fun =
|
|
145
|
+
dst_sink_fun = _get_jit_func(sink_fun, jit_config)
|
|
123
146
|
dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
|
|
124
147
|
|
|
125
148
|
return dst_sink_fun
|
|
@@ -214,8 +214,7 @@ def _get_dataset_aux(dataset):
|
|
|
214
214
|
|
|
215
215
|
def connect_network_with_dataset(network, dataset_helper):
|
|
216
216
|
"""
|
|
217
|
-
Connect the `network` with dataset in `dataset_helper`. Only supported in
|
|
218
|
-
<https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
|
|
217
|
+
Connect the `network` with dataset in `dataset_helper`. Only supported in sink mode,
|
|
219
218
|
(dataset_sink_mode=True).
|
|
220
219
|
|
|
221
220
|
Args:
|
|
@@ -335,11 +334,11 @@ class DatasetHelper:
|
|
|
335
334
|
dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
|
|
336
335
|
dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
|
|
337
336
|
Default: ``True``.
|
|
338
|
-
sink_size (int): Control the amount of data in each sink.
|
|
337
|
+
sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
|
|
339
338
|
If sink_size=-1, sink the complete dataset for each epoch.
|
|
340
339
|
If sink_size>0, sink sink_size data for each epoch.
|
|
341
|
-
Default:
|
|
342
|
-
epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1
|
|
340
|
+
Default: ``-1``.
|
|
341
|
+
epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.
|
|
343
342
|
|
|
344
343
|
Examples:
|
|
345
344
|
>>> import numpy as np
|