mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py}
RENAMED

@@ -15,24 +15,27 @@
 """Checkpoint related classes and functions."""
 
 import os
+from mindspore.utils import _tft_handler
 from mindspore.train.serialization import save_checkpoint
-from mindspore.parallel._utils import _get_device_num
-from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback
-from mindspore import context
+from mindspore import context, ops
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.communication import get_rank, get_group_size
 from mindspore import log as logger
 from mindspore.train.serialization import _get_cur_rank_dp
-from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
+from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
+from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
 from mindspore._c_expression import clean_tdt_channel
-from mindspore._c_expression import send_recv
+from mindspore._c_expression import send_recv, reset_params
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
+from mindspore.parallel._recovery_context import _set_recovery_context
+
 
 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     """ Common func to generate ckpt dir name."""
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
     return os.path.join(ckpt_save_path, mid_dir)
 
+
 def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
     """ Callback used for TFT save ckpt function when errors occur."""
     logger.info("Enter _save_checkpoint_on_failure function")
-    if not cb_ctx._is_params_consistent():
+    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
         raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
+    cb_params = args
+    # we record the current step and epoch num in on_train_step_end, so we can just reset it here
+    cb_params.cur_step_num = cb_ctx.cur_step_num
+    cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
+    if cb_params.optimizer is not None:
+        cb_params.optimizer.global_step = cb_ctx.global_step
+    if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+        cb_params.network.optimizer.global_step = cb_ctx.global_step
+    append_dict = {}
+    append_dict["__exception_save__"] = True
+    # if user has provided a custom save callback, use it
+    if cb_ctx.save_cb:
+        cb_ctx.save_cb(cb_params, append_dict)
+        logger.info("Finish _save_checkpoint_on_failure function")
+        return
 
+    # if user has not provided a custom save callback, use default save logic
     ckpt_save_path = cb_ctx.ckpt_save_path
-    cb_params = args
     cur_rank = get_rank()
-
+    step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
     cur_epoch_num = cb_params.cur_epoch_num
-    batch_num = cb_params.batch_num
-    if cur_step_num > step:
-        cur_epoch_num = (step - 1) // batch_num + 1
-        step_num_in_epoch = int((step - 1) % batch_num + 1)
-
-    append_dict = {}
     append_dict["epoch_num"] = cur_epoch_num
-    append_dict["step_num"] =
+    append_dict["step_num"] = cb_params.cur_step_num
     append_dict["cur_rank"] = cur_rank
-    append_dict["batch_num"] = batch_num
-    append_dict["
-
-    append_dict["global_step"] = Parameter([cb_ctx.global_step])
+    append_dict["batch_num"] = cb_params.batch_num
+    append_dict["global_step"] = cb_ctx.global_step
     outputs = cb_params.net_outputs
     if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
         append_dict["loss_scale"] = outputs[2]
@@ -76,47 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
                     integrated_save=False, append_dict=append_dict)
     logger.info("Finish _save_checkpoint_on_failure function")
 
+
 def _rename_save_result(step, cb_ctx):
     """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
     logger.info("Enter _rename_save_result function")
+    if cb_ctx.save_cb:
+        logger.info("User's save callback is provided, skip rename")
+        return
     tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
     fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
 
     os.rename(tmp_dir, fin_dir)
     logger.info("Finish _rename_save_result function")
 
+
 def _tft_exit_cb(ctx):
+    """Callback used for TFT exit function."""
     logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
     _tft_sem_post()
-    os._exit(1)
+    os._exit(1)  # pylint: disable=W0212
+
 
 def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
     """ Callback used for TFT repair function."""
-    logger.
-    if(repair_info["repair_type"]
-
-    logger.
+    logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
+        logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
         _repair_device(cb_ctx.device_id)
 
-    if(repair_info["repair_type"]
-
-
-    {}, src_rank:{}, dst_rank: {}".format(
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_SEND.value,
+                                       cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
+        logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
+            repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
         cb_params = args
-
-
-
-
+        if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
+            for i in range(len(repair_info["src"])):
+                src_rank = repair_info["src"][i]
+                dst_rank = repair_info["dst"][i]
+                if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                    raise ValueError("Call send_recv failed.")
+        else:
+            src_rank = repair_info["src"][0]
+            dst_rank = repair_info["dst"][0]
+            if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                raise ValueError("Call send_recv failed.")
+    logger.warning("Finish _tft_repair_callback")
 
 
 def _tft_clean_callback(is_uce_error, args, ctx):
     """ Callback used for TFT clean function."""
-    logger.
+    logger.warning("Enter _tft_clean_callback")
     ret = 0
     if is_uce_error:
         _get_uce_mem_info(ctx.device_id)
         err_strategy = _get_uce_process_strategy()
-        logger.
+        logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
         if err_strategy == "RS_UCE_HIGHLEVEL":
             ret = 0
         elif err_strategy == "RS_UCE_LOWLEVEL":
@@ -124,59 +151,81 @@ def _tft_clean_callback(is_uce_error, args, ctx):
         else:
             ret = 1
     clean_tdt_channel()
-    logger.
+    logger.warning("Enter _tft_clean_callback resume_hccl_comm")
    CollectiveManager.get_instance().resume_hccl_comm()
-    logger.
+    logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     return ret
 
 
 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
-    logger.
+    logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
     _stop_device(cb_ctx.device_id)
-    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):
+    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    if cb_ctx.tft.tft_get_repair_type() == "recover":
+        logger.warning(f"Reset limit step")
+        cb_ctx.tft.tft_reset_limit_step()
     logger.info("Finish _tft_stop_callback")
 
 
-
+def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
+    """Callback used for TFT Rebuild Group function."""
+    logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
+    _finalize_comm()
+    _rebuild_world_group()
+    _rebuild_sub_group()
+    _set_recovery_context(is_arf=True)
+    logger.warning("Enter _tft_rebuild_sub_groups ok ")
+
+
+class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
-
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
+    and will execute TFT operations during training process, such as TFT init, report and exception handle.
 
     Note:
         Required for Ascend graph mode only. And sink size must be less than or equal to 1.
 
     Args:
-
-
-
-
-            named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
+        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
+            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
+            is created in that directory. Default: ``None``.
+        kwargs (dict): Other dictionary type parameters.
 
     Raises:
         Exception: TFT init failed.
         ModuleNotFoundError: Mindio TFT whl package is not installed.
 
     Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            It's recommended to use the msrun startup method.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
+            for more details.
+
+            This example should be run with 4 devices.
+
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
-        >>> from mindspore.communication import init
+        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
-        >>> from mindspore.train import Model,
+        >>> from mindspore.train import Model, TrainFaultTolerance
        >>> from mindspore import dataset as ds
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
-
+        ...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
@@ -234,48 +283,74 @@ class TFTRegister(Callback):
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
-        >>>
+        >>> dataset = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
-        >>> net_with_loss = nn.
+        >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
-        >>> model = Model(net_with_loss, optimizer=
-        >>> tft_cb =
+        >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
+        >>> tft_cb = TrainFaultTolerance()
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
     """
 
-    def __init__(self,
-        super(
-
-
-        if
-            raise ValueError("
-        mode = context.get_context("mode")
-        device_target = context.get_context("device_target")
-        if device_target != "Ascend" or mode != context.GRAPH_MODE:
-            raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
-
-        # let it raise errors if not install mindio_tft package
-        from mindio_ttp import framework_ttp as tft
-        self.tft = tft
-        self.is_uce_rank = False
-        self.global_step = 0
-        Validator.check_non_negative_int(ctrl_port)
-        self.has_init_replica = False
-        self._controller_ip = ctrl_ip
-        self._controller_rank_id = ctrl_rank_id
-        self._controller_port = ctrl_port
+    def __init__(self, ckpt_save_path=None, **kwargs):
+        super(TrainFaultTolerance, self).__init__()
+        self.save_cb = kwargs.get("ckpt_save_fn", None)
+        self.ckpt_save_path = ckpt_save_path
+        if self.save_cb is None and self.ckpt_save_path is None:
+            raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
         self.cb_params = None
+        self.initial_step = kwargs.get("initial_step", 0)
         self.device_id = context.get_context("device_id")
-        self.
-        self.
+        self.cur_step_num = 0
+        self.cur_epoch_num = 0
+        # For TREError(Training Result Error) scene, parameter `ckpt_load_fn` must be provided to load checkpoint
+        # from file for resuming training, the `ckpt_load_fn` is a function, prototype of which is:
+        # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
+        # i.e. (param_dict, remove_redundancy)
+        self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
+        self.tft = _tft_handler.get_tft()
+        if self._only_enable_tre():
+            return
+        self._check_init()
+        self.global_step = None
+        self.learning_rate = None
+        self.has_init_replica = False
+        self.is_uce_rank = False
+
         self.assign = mindspore.ops.Assign()
         self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
         self.s1 = mindspore.hal.Stream()
+        _tft_sem_enable()
+        self._tft_register()
+
+    def _only_enable_tre(self):
+        """Check if only configured MS_ENABLE_TFT='{TRE:1}'"""
+        env_enable = os.getenv("MS_ENABLE_TFT", "")
+        non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
+        if any(flag in env_enable for flag in non_tre_flags):
+            return False
+        return "TRE:1" in env_enable
+
+    def _check_init(self):
+        """Check if the mindio-ttp had inited"""
+        if self.tft is None:
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            if "ARF:1" in tft_env:
+                raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
+            logger.warning(f"TFT handle not init, try to init")
+            _tft_handler.init(config=None)
+            self.tft = _tft_handler.get_tft()
+            logger.warning(f"TFT handle init ok.")
+        mode = context.get_context("mode")
+        device_target = context.get_context("device_target")
+        if device_target != "Ascend" or mode != context.GRAPH_MODE:
+            raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
+                             f"device:{device_target}, run mode: {mode}")
 
     def _is_params_consistent(self):
         for key, param in self.cb_params.train_network.parameters_and_names():
@@ -287,7 +362,7 @@ class TFTRegister(Callback):
         return False
 
     def _set_tft_optimizer_replica(self, run_context):
-        """
+        """ Set Mindio TFT optimizer replica info, used internal. """
         cur_rank = get_rank()
         cb_params = run_context.original_args()
         train_network = cb_params.train_network
@@ -309,59 +384,98 @@ class TFTRegister(Callback):
         ]
         self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
 
-
-
-
+    @classmethod
+    def get_optimizer_wrapper(cls, origin_opt_cls):
+        """
+        Optimizer wrapper func when using tft.
+
+        Args:
+            origin_opt_cls (Class): origin optimizer class.
+        """
+
+        class TFTOptSubCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSubCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report_end = TensorReport()
+                self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+                self.depend = ops.Depend()
+                self.allreduce_sum = ops.AllReduce()
+                self.allreduce_sum.add_prim_attr("tft_report_before", True)
+                self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
+
+            def construct(self, gradients, **kwargs):
+                tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
+                self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
+                grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
+                opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
+                self.report_end("tft_report", self.tft_g_one_flag)
+                return opt_ret
+
+        return TFTOptSubCls
+
+    def _tft_register(self):
+        """Register callback functions."""
         self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
         self.tft.tft_register_rename_handler(_rename_save_result, self)
         self.tft.tft_register_exit_handler(_tft_exit_cb, self)
         self.tft.tft_register_stop_handler(_tft_stop_callback, self)
         self.tft.tft_register_clean_handler(_tft_clean_callback, self)
         self.tft.tft_register_repair_handler(_tft_repair_callback, self)
+        self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
 
-
-
-
-
-
-
-
-        if cur_rank == self._controller_rank_id:
-            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
-            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
-            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
-            logger.info("Finish start tft controller.")
-
-        logger.info("Begin to start tft processor.")
-        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
-        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
-        logger.info("Finished start tft processor.")
+    def _reset_acc_grads(self):
+        accu_grad_params = map(lambda e: e[1],
+                               filter(lambda e: e[1].name.startswith('accu_grads'),
+                                      self.cb_params.train_network.parameters_and_names()))
+        accu_grad_list = list(accu_grad_params)
+        if reset_params(accu_grad_list) != 0:
+            raise ValueError("Call reset_params failed.")
 
     def on_train_step_end(self, run_context):
         """
-
+        Report status to MindIO TFT after every step finished.
 
         Args:
             run_context (RunContext): Context of the train running. Refer to
                 :class:`mindspore.train.RunContext` for detail.
         """
+        if self._only_enable_tre():
+            return
        if self.has_init_replica is False:
            self.has_init_replica = True
            self._set_tft_optimizer_replica(run_context)
        cb_params = run_context.original_args()
        logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
-        self.
+        self.cur_step_num = cb_params.cur_step_num
+        self.cur_epoch_num = cb_params.cur_epoch_num
        if cb_params.optimizer is not None:
-            self.global_step =
+            self.global_step = cb_params.optimizer.global_step.clone()
            self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
-
-            self.global_step =
+        elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+            self.global_step = cb_params.network.optimizer.global_step.clone()
            self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
+        else:
+            raise ValueError("TFT feature need optimizer or network's optimizer!")
+        self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
        logger.info("END Set optimizer finish step status to TFT.")
 
-
     def on_train_begin(self, run_context):
+        """
+        Register train params to MindIO TFT on train beginning.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
        cb_params = run_context.original_args()
+        if self._only_enable_tre():
+            self.cb_params = cb_params
+            return
        sink_size = cb_params.get("sink_size", 0)
        if sink_size > 1:
            raise ValueError("TFT feature doesn't support sink_size > 1.")
@@ -370,7 +484,13 @@ class TFTRegister(Callback):
         self.cb_params = cb_params
 
     def end(self, run_context):
-
-
-
-
+        """
+        Unregister MindIO TFT on train end.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        if self._only_enable_tre():
+            return
+        _tft_handler.unregister_tft()
mindspore/train/data_sink.py
CHANGED
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
     return next_op, (key, dataset_shapes, dataset_types)
 
 
+def _get_jit_func(sink_fun, jit_config):
+    """
+    Get the jit function.
+    """
+    jit_config_dict = jit_config.jit_config_dict
+    jit_level = jit_config_dict['jit_level']
+    if jit_level == "":
+        jit_level = "O0"
+    backend = ""
+    if jit_level == "O2":
+        jit_level = "O0"
+        backend = "GE"
+    if "backend" in jit_config_dict:
+        backend = jit_config_dict["backend"]
+    fullgraph = False
+    if jit_config_dict['jit_syntax_level'] == "STRICT":
+        fullgraph = True
+    exc_mode = jit_config_dict['exc_mode']
+    infer_boost = jit_config_dict['infer_boost']
+    return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
+               infer_boost=infer_boost)
+
+
 def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     """
     get the sink function.
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     if jit_config is None:
         dst_sink_fun = sink_fun
     else:
-        dst_sink_fun =
+        dst_sink_fun = _get_jit_func(sink_fun, jit_config)
     dataset.__sink_fun__ = dst_sink_fun
 
     return dataset.__sink_fun__
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     if jit_config is None:
         dst_sink_fun = sink_fun
     else:
-        dst_sink_fun =
+        dst_sink_fun = _get_jit_func(sink_fun, jit_config)
     dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
 
     return dst_sink_fun
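
The new _get_jit_func above translates a JitConfig dict into arguments for the functional jit API: an empty jit_level falls back to "O0", the legacy "O2" level is lowered to "O0" plus the "GE" backend (unless the config names a backend explicitly), and jit_syntax_level == "STRICT" becomes fullgraph=True. A hedged sketch of the resulting call for an "O2" config; my_sink_fun is a hypothetical sink function:

    from mindspore import jit

    # Equivalent of what _get_jit_func builds for jit_level="O2"
    # (exc_mode and infer_boost are forwarded from the config as well):
    compiled_sink = jit(my_sink_fun, jit_level="O0", backend="GE", fullgraph=False)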
mindspore/train/dataset_helper.py
CHANGED

@@ -15,7 +15,6 @@
 """Dataset help for minddata dataset"""
 from __future__ import absolute_import
 
-import os
 import math
 import copy
 
@@ -25,9 +24,11 @@ from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
 from mindspore.common._utils import is_shape_unknown
+from mindspore.dataset.core import config as dataset_config
 from mindspore.dataset.engine import offload
 from mindspore import context, nn
-from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes,
+from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
+    _construct_tensor_list, enable_data_broadcast
 from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
     _to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
     _origin_shapes, _dynamic_shape_for_dataset
@@ -213,8 +214,7 @@ def _get_dataset_aux(dataset):
 
 def connect_network_with_dataset(network, dataset_helper):
     """
-    Connect the `network` with dataset in `dataset_helper`. Only supported in
-    <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
+    Connect the `network` with dataset in `dataset_helper`. Only supported in sink mode,
     (dataset_sink_mode=True).
 
     Args:
@@ -263,16 +263,14 @@ def connect_network_with_dataset(network, dataset_helper):
             "The dataset has been connected to other network, please check the code.")
     is_dynamic = bool(network.get_inputs())
     queue_name = dataset.__transfer_dataset__.queue_name
+
     # In pipeline parallel, some stages have no GetNext, should not get in.
+    # Don't enable dynamic shape(multi-subgraph) feature in pp/dataset_broadcast mode,
+    # otherwise get_data_info will stuck since some rank do not consume data.
     use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+    data_broadcast = enable_data_broadcast()
 
-
-    dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
-    dynamic_sink1 = True
-    if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
-        dynamic_sink1 = False
-
-    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and dynamic_sink1:
+    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and not data_broadcast:
         dataset_types, dataset_shapes = dataset_helper.get_data_info()
         # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
         if _need_to_full():
@@ -314,7 +312,7 @@ def connect_network_with_dataset(network, dataset_helper):
         aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
 
     if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
-            not use_pipeline_parallel and
+            not use_pipeline_parallel and not data_broadcast:
         dataset_helper.get_data_info()
         network.add_flags(sink_mode=True)
         return network
@@ -336,11 +334,11 @@ class DatasetHelper:
         dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
             dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
             Default: ``True``.
-        sink_size (int): Control the amount of data in each sink.
+        sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
             If sink_size=-1, sink the complete dataset for each epoch.
             If sink_size>0, sink sink_size data for each epoch.
-            Default:
-        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1
+            Default: ``-1``.
+        epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.
 
     Examples:
         >>> import numpy as np
@@ -686,8 +684,9 @@ class _DatasetIterNormal:
         self.dataset = dataset
         self.device_num = _get_device_num()
         self.global_rank = _get_global_rank()
+        do_copy = dataset_config.get_iterator_mode()["do_copy"]
         self.iter = self.dataset.create_tuple_iterator(
-            num_epochs=epoch_num, do_copy=
+            num_epochs=epoch_num, do_copy=do_copy)
 
     def __iter__(self):
         return self