mindspore 2.4.10-cp310-cp310-win_amd64.whl → 2.6.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/train/model.py
CHANGED
|
@@ -27,7 +27,6 @@ import time
|
|
|
27
27
|
import numpy as np
|
|
28
28
|
|
|
29
29
|
import mindspore
|
|
30
|
-
import mindspore.dataset as ds
|
|
31
30
|
from mindspore import log as logger
|
|
32
31
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
|
33
32
|
from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
|
|
@@ -36,7 +35,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
|
|
|
36
35
|
from mindspore._checkparam import check_input_data, check_output_data
|
|
37
36
|
from mindspore import _checkparam as Validator
|
|
38
37
|
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
|
|
39
|
-
|
|
38
|
+
TrainFaultTolerance
|
|
40
39
|
from mindspore.train.callback import __all__ as internal_cb_names
|
|
41
40
|
from mindspore.train.callback._cluster_monitor import ClusterMonitor
|
|
42
41
|
from mindspore import context
|
|
@@ -46,7 +45,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
|
|
|
46
45
|
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
|
|
47
46
|
_cache_enable, _enable_distributed_mindrt
|
|
48
47
|
from mindspore.train.metrics import Loss
|
|
49
|
-
from mindspore.
|
|
48
|
+
from mindspore.log import vlog_print
|
|
50
49
|
from mindspore import nn
|
|
51
50
|
from mindspore.boost import AutoBoost
|
|
52
51
|
from mindspore.context import ParallelMode
|
|
@@ -57,7 +56,11 @@ from mindspore.dataset.core.config import get_debug_mode
|
|
|
57
56
|
from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
|
|
58
57
|
from mindspore.train import amp
|
|
59
58
|
from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
|
|
59
|
+
from mindspore._c_expression import _get_optimzer_timestamps
|
|
60
|
+
from mindspore._c_expression import clean_tdt_channel
|
|
60
61
|
|
|
62
|
+
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
63
|
+
from .serialization import load_param_into_net
|
|
61
64
|
|
|
62
65
|
def _transfer_tensor_to_tuple(inputs):
|
|
63
66
|
"""
|
|
@@ -91,6 +94,7 @@ def _save_final_ckpt(func):
|
|
|
91
94
|
"""
|
|
92
95
|
Decorator function, which saves the current checkpoint when an exception occurs during training.
|
|
93
96
|
"""
|
|
97
|
+
|
|
94
98
|
@wraps(func)
|
|
95
99
|
def wrapper(self, *args, **kwargs):
|
|
96
100
|
obj = None
|
|
@@ -107,7 +111,7 @@ def _save_final_ckpt(func):
|
|
|
107
111
|
# pylint: disable=W0212
|
|
108
112
|
prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
|
|
109
113
|
cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
|
|
110
|
-
|
|
114
|
+
+ str(self._current_step_num) + "_breakpoint.ckpt"
|
|
111
115
|
cur_file = os.path.join(obj._directory, cur_ckpoint_file)
|
|
112
116
|
if "epoch_num" in obj._append_dict:
|
|
113
117
|
obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
|
|
@@ -118,85 +122,172 @@ def _save_final_ckpt(func):
|
|
|
118
122
|
raise e
|
|
119
123
|
else:
|
|
120
124
|
func(self, *args, **kwargs)
|
|
125
|
+
|
|
121
126
|
return wrapper
|
|
122
127
|
|
|
128
|
+
|
|
129
|
+
def _handle_exception_info(obj, uce_env, tft, e):
|
|
130
|
+
"""handle exception info"""
|
|
131
|
+
logger.info("uce wrapper caught RuntimeError")
|
|
132
|
+
if not uce_env:
|
|
133
|
+
logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
|
|
134
|
+
exc_info=True)
|
|
135
|
+
if tft:
|
|
136
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
137
|
+
raise e
|
|
138
|
+
e_str = str(e)
|
|
139
|
+
logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
140
|
+
if "UCEError" in e_str:
|
|
141
|
+
logger.info("uce wrapper report UCEError")
|
|
142
|
+
obj.is_uce_rank = True
|
|
143
|
+
# if error is HBM_MULTI_BIT_ECC_ERROR
|
|
144
|
+
if "error_code=507054" in e_str:
|
|
145
|
+
hbm_error_time, optimize_start, optimizer_end = _get_optimzer_timestamps()
|
|
146
|
+
can_repair = tft.tft_can_do_uce_repair(hbm_error_time, optimize_start, optimizer_end)
|
|
147
|
+
logger.info(f"UCEError of type HBM_MULTI_BIT_ECC_ERROR occurs, \
|
|
148
|
+
hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, \
|
|
149
|
+
optimizer_end={optimizer_end}, can_repair={can_repair}")
|
|
150
|
+
if not can_repair:
|
|
151
|
+
logger.error(f"Caught UCEError of type HBM_MULTI_BIT_ECC_ERROR but can not repair, "
|
|
152
|
+
f"hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, "
|
|
153
|
+
f"optimizer_end={optimizer_end}", exc_info=True)
|
|
154
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
155
|
+
raise e
|
|
156
|
+
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
157
|
+
elif "ForceStopError" in e_str:
|
|
158
|
+
logger.warning("uce wrapper caught RuntimeError ForceStopError")
|
|
159
|
+
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
160
|
+
tft.tft_report_error(force_stop_err)
|
|
161
|
+
elif "ARF FINISH" in e_str:
|
|
162
|
+
logger.warning(f"ARF FINISH")
|
|
163
|
+
_set_recovery_context(is_arf=True)
|
|
164
|
+
tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
|
|
165
|
+
else:
|
|
166
|
+
logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
|
|
167
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
168
|
+
raise e
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _handle_training_result_error(model, tft_obj):
|
|
172
|
+
"""
|
|
173
|
+
Handle training result error for resuming training.
|
|
174
|
+
"""
|
|
175
|
+
ckpt_load_fn = tft_obj.ckpt_load_func
|
|
176
|
+
train_network = tft_obj.cb_params.train_network
|
|
177
|
+
logger.warning("Process training result error start.")
|
|
178
|
+
# 1. Clear tdt channel
|
|
179
|
+
logger.warning("Clean tdt channel.")
|
|
180
|
+
clean_tdt_channel()
|
|
181
|
+
|
|
182
|
+
# 2. Load checkpoint
|
|
183
|
+
logger.warning("Load checkpoint.")
|
|
184
|
+
new_param_dict, remove_redundancy = ckpt_load_fn()
|
|
185
|
+
param_not_load, ckpt_not_load = load_param_into_net(train_network, new_param_dict, True, remove_redundancy)
|
|
186
|
+
logger.warning(f"param_not_load: {param_not_load}")
|
|
187
|
+
logger.warning(f"ckpt_not_load: {ckpt_not_load}")
|
|
188
|
+
resume_epoch = new_param_dict.get('epoch_num')
|
|
189
|
+
resume_step = new_param_dict.get('step_num')
|
|
190
|
+
model._initial_step = int(resume_step.asnumpy())
|
|
191
|
+
logger.warning("Process training result error end.")
|
|
192
|
+
return (resume_epoch, resume_step)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _calc_cb_initial_step(org_epoch, org_step, *args, **kwargs):
|
|
196
|
+
"""calculate initial step for callback"""
|
|
197
|
+
train_dataset = args[1]
|
|
198
|
+
dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
|
|
199
|
+
sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
|
|
200
|
+
|
|
201
|
+
cb_initial_step = 0
|
|
202
|
+
if dataset_sink_mode:
|
|
203
|
+
train_dataset.set_init_step(org_epoch)
|
|
204
|
+
dataset_size = train_dataset.get_dataset_size()
|
|
205
|
+
if sink_size != -1:
|
|
206
|
+
cb_initial_step = org_epoch * sink_size + org_step
|
|
207
|
+
else:
|
|
208
|
+
cb_initial_step = org_epoch * dataset_size + org_step
|
|
209
|
+
else:
|
|
210
|
+
train_dataset.set_init_step(org_step)
|
|
211
|
+
cb_initial_step = org_step
|
|
212
|
+
if hasattr(train_dataset, '_dataset_helper'):
|
|
213
|
+
dataset_helper = train_dataset._dataset_helper
|
|
214
|
+
_reset_training_dataset(cb_initial_step, dataset_helper.iter.dataset.get_dataset_size())
|
|
215
|
+
return cb_initial_step
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _update_ckpt_callback_info(resume_train_step, **kwargs):
|
|
219
|
+
"""
|
|
220
|
+
Update checkpoint callback internal state
|
|
221
|
+
"""
|
|
222
|
+
ckpt_obj = None
|
|
223
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
|
|
224
|
+
ckpt_obj = kwargs.get('callbacks')
|
|
225
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
|
226
|
+
for item in kwargs.get('callbacks'):
|
|
227
|
+
if isinstance(item, ModelCheckpoint):
|
|
228
|
+
ckpt_obj = item
|
|
229
|
+
if ckpt_obj is not None:
|
|
230
|
+
ckpt_obj._last_triggered_step = 0
|
|
231
|
+
ckpt_obj._append_step_num = resume_train_step
|
|
232
|
+
|
|
233
|
+
|
|
123
234
|
def _handle_tft(func):
|
|
124
235
|
"""
|
|
125
236
|
Decorator function, which starts uce handle process when an exception occurs during training.
|
|
126
237
|
"""
|
|
238
|
+
|
|
127
239
|
@wraps(func)
|
|
128
240
|
def wrapper(self, *args, **kwargs):
|
|
129
241
|
obj = None
|
|
130
|
-
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'),
|
|
242
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
|
|
131
243
|
obj = kwargs.get('callbacks')
|
|
132
244
|
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
|
133
245
|
for item in kwargs.get('callbacks'):
|
|
134
|
-
if isinstance(item,
|
|
246
|
+
if isinstance(item, TrainFaultTolerance):
|
|
135
247
|
obj = item
|
|
136
248
|
if obj:
|
|
137
249
|
tft = obj.tft
|
|
138
250
|
tft_env = os.getenv("MS_ENABLE_TFT", "")
|
|
139
|
-
uce_env = "UCE:1" in tft_env
|
|
251
|
+
uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
|
|
252
|
+
tre_env = "TRE:1" in tft_env
|
|
140
253
|
while True:
|
|
141
254
|
try:
|
|
142
255
|
return func(self, *args, **kwargs)
|
|
143
256
|
except RuntimeError as e:
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
e_str = str(e)
|
|
150
|
-
logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
151
|
-
if "UCEError" in e_str:
|
|
152
|
-
obj.is_uce_rank = True
|
|
153
|
-
logger.info("uce wrapper report UCEError")
|
|
154
|
-
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
155
|
-
elif "ForceStopError" in e_str:
|
|
156
|
-
logger.info("uce wrapper caught RuntimeError ForceStopError")
|
|
157
|
-
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
158
|
-
tft.tft_report_error(force_stop_err)
|
|
257
|
+
if tre_env and 'TREError' in str(e):
|
|
258
|
+
_, resume_step = _handle_training_result_error(self, obj)
|
|
259
|
+
repair_step = int(resume_step.asnumpy())
|
|
260
|
+
_update_ckpt_callback_info(repair_step, **kwargs)
|
|
261
|
+
logger.warning(f'Resume training after TREError from step {repair_step}.')
|
|
159
262
|
else:
|
|
160
|
-
|
|
161
|
-
tft.
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
initial_epoch = int(repair_step/self.batch_num)
|
|
263
|
+
_handle_exception_info(obj, uce_env, tft, e)
|
|
264
|
+
ret = tft.tft_wait_next_action()
|
|
265
|
+
if ret == tft.Action.EXIT.value:
|
|
266
|
+
raise e
|
|
267
|
+
repair_step = tft.tft_get_repair_step()
|
|
268
|
+
logger.warning(
|
|
269
|
+
"uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
|
|
270
|
+
self.batch_num))
|
|
271
|
+
initial_epoch = int(repair_step / self.batch_num)
|
|
170
272
|
initial_step = repair_step % self.batch_num
|
|
171
273
|
kwargs["initial_epoch"] = initial_epoch
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
dataset_size = train_dataset.get_dataset_size()
|
|
181
|
-
if sink_size != -1:
|
|
182
|
-
cb_initial_step = initial_epoch * sink_size + initial_step
|
|
183
|
-
else:
|
|
184
|
-
cb_initial_step = initial_epoch * dataset_size + initial_step
|
|
185
|
-
else:
|
|
186
|
-
train_dataset.set_init_step(initial_step)
|
|
187
|
-
cb_initial_step = initial_step
|
|
188
|
-
|
|
189
|
-
kwargs["initial_step"] = cb_initial_step
|
|
190
|
-
|
|
191
|
-
logger.info("uce wrapper repair complete \
|
|
192
|
-
initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
|
|
274
|
+
cb_initial_step = _calc_cb_initial_step(initial_epoch, initial_step, *args, **kwargs)
|
|
275
|
+
if not self.enable_tre:
|
|
276
|
+
kwargs["initial_step"] = cb_initial_step
|
|
277
|
+
# reset all accu grads to zero
|
|
278
|
+
obj._reset_acc_grads()
|
|
279
|
+
logger.warning(
|
|
280
|
+
"uce wrapper repair complete initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch,
|
|
281
|
+
cb_initial_step))
|
|
193
282
|
continue
|
|
194
283
|
except BaseException as e:
|
|
195
|
-
|
|
196
|
-
|
|
284
|
+
if tft:
|
|
285
|
+
logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
|
|
286
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
197
287
|
raise e
|
|
198
288
|
else:
|
|
199
289
|
return func(self, *args, **kwargs)
|
|
290
|
+
|
|
200
291
|
return wrapper
|
|
201
292
|
|
|
202
293
|
|
|
@@ -213,7 +304,7 @@ def _check_tft():
|
|
|
213
304
|
if ms_mode != mindspore.GRAPH_MODE:
|
|
214
305
|
raise ValueError("TFT is only supported in GRAPH_MODE")
|
|
215
306
|
jit_level = context.get_context("jit_level")
|
|
216
|
-
if jit_level == "O2" and "UCE:1" in tft_env:
|
|
307
|
+
if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
|
|
217
308
|
raise ValueError("TFT is not supported when using jit_level == O2")
|
|
218
309
|
|
|
219
310
|
|
|
@@ -403,12 +494,13 @@ class Model:
|
|
|
403
494
|
the accuracy is reduced by less than 3%.
|
|
404
495
|
|
|
405
496
|
If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
|
|
406
|
-
In order for this function to work, you need to set the optimizer
|
|
407
|
-
at the
|
|
497
|
+
In order for this function to work, you need to set the parameter `optimizer`, along with
|
|
498
|
+
at least one of the parameter `eval_network` or performance `metrics`.
|
|
408
499
|
|
|
409
500
|
Notice: The current optimization enabled by default only applies to some networks, and not all networks
|
|
410
501
|
can obtain the same benefits. It is recommended to enable this function on
|
|
411
|
-
the Graph mode + Ascend platform, and for better acceleration,
|
|
502
|
+
the Graph mode + Ascend platform, and for better acceleration,
|
|
503
|
+
refer to :class:`mindspore.boost.AutoBoost` to configure
|
|
412
504
|
boost_config_dict.
|
|
413
505
|
|
|
414
506
|
Examples:
|
|
@@ -433,6 +525,7 @@ class Model:
|
|
|
433
525
|
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
|
|
434
526
|
amp_level="O0", boost_level="O0", **kwargs):
|
|
435
527
|
self._network = network
|
|
528
|
+
_init_auto_parallel_context(self._network)
|
|
436
529
|
self._loss_fn = loss_fn
|
|
437
530
|
self._optimizer = optimizer
|
|
438
531
|
self._loss_scale_manager = None
|
|
@@ -467,6 +560,9 @@ class Model:
|
|
|
467
560
|
self._lite_infer = True # if backend lite infer fails, set False
|
|
468
561
|
self._mindspore_lite_model_group_id = id(self) & 0xFFFF
|
|
469
562
|
self.batch_num = -1
|
|
563
|
+
self.enable_tre = "TRE:1" in os.getenv("MS_ENABLE_TFT", "")
|
|
564
|
+
self._initial_step = None
|
|
565
|
+
_clear_auto_parallel_context(self._network)
|
|
470
566
|
|
|
471
567
|
def _check_for_graph_cell(self, kwargs):
|
|
472
568
|
"""Check for graph cell"""
|
|
@@ -665,7 +761,7 @@ class Model:
|
|
|
665
761
|
logger.info("Begin to connect network with dataset.")
|
|
666
762
|
network = connect_network_with_dataset(network, dataset_helper)
|
|
667
763
|
|
|
668
|
-
if _get_recovery_context("enable_recovery") and is_train:
|
|
764
|
+
if (_get_recovery_context("enable_recovery") or self.enable_tre) and is_train:
|
|
669
765
|
_set_training_dataset(dataset_helper)
|
|
670
766
|
|
|
671
767
|
network.set_train(is_train)
|
|
@@ -762,7 +858,7 @@ class Model:
|
|
|
762
858
|
break
|
|
763
859
|
logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
|
|
764
860
|
|
|
765
|
-
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
|
|
861
|
+
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
|
|
766
862
|
"""
|
|
767
863
|
Initialize compute graphs and data graphs with the sink mode.
|
|
768
864
|
|
|
@@ -791,7 +887,6 @@ class Model:
|
|
|
791
887
|
if not isinstance(train_dataset, mindspore.dataset.Dataset):
|
|
792
888
|
raise TypeError("The type of 'train_dataset' must be `Dataset`, "
|
|
793
889
|
"but got {}.".format(type(train_dataset)))
|
|
794
|
-
|
|
795
890
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
796
891
|
"Begin to check parameter broadcast in model.build().")
|
|
797
892
|
logger.info("Begin to check parameter broadcast in model.build() procedure.")
|
|
@@ -804,23 +899,24 @@ class Model:
|
|
|
804
899
|
train_dataset.__no_send__ = True
|
|
805
900
|
train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
|
|
806
901
|
dataset=train_dataset,
|
|
807
|
-
dataset_sink_mode=
|
|
902
|
+
dataset_sink_mode=sink_mode,
|
|
808
903
|
sink_size=sink_size)
|
|
809
904
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
|
|
810
|
-
|
|
811
|
-
|
|
905
|
+
if sink_mode:
|
|
906
|
+
logger.info("Begin to warmup dataset in model.build() procedure.")
|
|
907
|
+
self._warmup_dataset(epoch, train_dataset, sink_size)
|
|
812
908
|
|
|
813
|
-
|
|
814
|
-
|
|
909
|
+
# Since dataset pipeline has been triggered, delete flag
|
|
910
|
+
delattr(train_dataset, "__no_send__")
|
|
815
911
|
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
912
|
+
# Waiting for the dataset warmup ready
|
|
913
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
914
|
+
"Begin waiting for dataset warmup in model.build().")
|
|
915
|
+
logger.info("Begin waiting for dataset warmup in model.build() procedure.")
|
|
916
|
+
self._waiting_for_dataset_warmup_ready(train_dataset)
|
|
917
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
918
|
+
"The dataset warmup was successful in model.build().")
|
|
919
|
+
logger.info("The dataset warmup was successful in model.build() procedure.")
|
|
824
920
|
|
|
825
921
|
if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
|
|
826
922
|
train_network.add_flags_recursive(is_first_iteration=True)
|
|
@@ -830,6 +926,7 @@ class Model:
|
|
|
830
926
|
logger.info("Begin to compile train network in model.build() procedure.")
|
|
831
927
|
train_network.compile(*inputs)
|
|
832
928
|
self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
|
|
929
|
+
train_dataset.reset()
|
|
833
930
|
break
|
|
834
931
|
|
|
835
932
|
if valid_dataset:
|
|
@@ -843,7 +940,7 @@ class Model:
|
|
|
843
940
|
valid_dataset.__no_send__ = True
|
|
844
941
|
valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
|
|
845
942
|
dataset=valid_dataset,
|
|
846
|
-
dataset_sink_mode=
|
|
943
|
+
dataset_sink_mode=sink_mode)
|
|
847
944
|
if context.get_auto_parallel_context("pipeline_stages") > 1:
|
|
848
945
|
eval_network.add_flags_recursive(is_first_iteration=False)
|
|
849
946
|
for inputs in valid_dataset_helper:
|
|
@@ -851,6 +948,7 @@ class Model:
|
|
|
851
948
|
"Begin to compile eval network in model.build().")
|
|
852
949
|
logger.info("Begin to compile eval network in model.build() procedure.")
|
|
853
950
|
eval_network.compile(*inputs)
|
|
951
|
+
valid_dataset.reset()
|
|
854
952
|
break
|
|
855
953
|
|
|
856
954
|
@staticmethod
|
|
@@ -908,10 +1006,6 @@ class Model:
|
|
|
908
1006
|
cb_params.list_callback = self._transform_callbacks(callbacks)
|
|
909
1007
|
valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
|
|
910
1008
|
cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
|
|
911
|
-
if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
|
|
912
|
-
FlopsUtilizationCollector not in cb_params.list_callback:
|
|
913
|
-
cb_params.list_callback.insert(0, FlopsUtilizationCollector(
|
|
914
|
-
cb_params.batch_num, full_flops=False))
|
|
915
1009
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
916
1010
|
cb_params.list_callback.insert(0, _StepSync())
|
|
917
1011
|
callbacks = cb_params.list_callback
|
|
@@ -923,6 +1017,8 @@ class Model:
|
|
|
923
1017
|
cb_params.last_save_ckpt_step = None
|
|
924
1018
|
cb_params.latest_ckpt_file = None
|
|
925
1019
|
cb_params.loss_scale_mananger = self._loss_scale_manager
|
|
1020
|
+
cb_params.is_arf = _get_recovery_context("is_arf")
|
|
1021
|
+
cb_params.initial_step = self._initial_step
|
|
926
1022
|
|
|
927
1023
|
# build callback list
|
|
928
1024
|
with _CallbackManager(callbacks) as list_callback:
|
|
@@ -1027,6 +1123,9 @@ class Model:
|
|
|
1027
1123
|
need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
|
|
1028
1124
|
if need_exec_callback_step_end:
|
|
1029
1125
|
list_callback.on_train_step_end(run_context)
|
|
1126
|
+
if cb_params.is_arf:
|
|
1127
|
+
cb_params.is_arf = False
|
|
1128
|
+
_set_recovery_context(is_arf=False)
|
|
1030
1129
|
|
|
1031
1130
|
# Embedding cache server only run one step.
|
|
1032
1131
|
if is_embedding_cache_server:
|
|
@@ -1057,7 +1156,7 @@ class Model:
|
|
|
1057
1156
|
if should_stop:
|
|
1058
1157
|
break
|
|
1059
1158
|
|
|
1060
|
-
need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset")\
|
|
1159
|
+
need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
|
|
1061
1160
|
and not _get_recovery_context("latest_ckpt_file")
|
|
1062
1161
|
self.epoch_iter += 1
|
|
1063
1162
|
if need_reset_to_beginning:
|
|
@@ -1101,7 +1200,7 @@ class Model:
|
|
|
1101
1200
|
Check whether enable recovery and execution mode consistency.
|
|
1102
1201
|
"""
|
|
1103
1202
|
|
|
1104
|
-
enable_recovery = _get_recovery_context("enable_recovery")
|
|
1203
|
+
enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
|
|
1105
1204
|
if not enable_recovery:
|
|
1106
1205
|
self.enable_recovery = False
|
|
1107
1206
|
else:
|
|
@@ -1118,6 +1217,8 @@ class Model:
|
|
|
1118
1217
|
dataset_size (int): The number of batches in a dataset.
|
|
1119
1218
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
|
1120
1219
|
"""
|
|
1220
|
+
if context.get_context("device_target") != "GPU":
|
|
1221
|
+
return
|
|
1121
1222
|
if not self.enable_recovery:
|
|
1122
1223
|
self.need_load_ckpt = False
|
|
1123
1224
|
|
|
@@ -1146,7 +1247,7 @@ class Model:
|
|
|
1146
1247
|
load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
|
|
1147
1248
|
except BaseException as e:
|
|
1148
1249
|
os.remove(cb_params.latest_ckpt_file)
|
|
1149
|
-
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
|
|
1250
|
+
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
|
|
1150
1251
|
+ cb_params.latest_ckpt_file) from e
|
|
1151
1252
|
_reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
|
|
1152
1253
|
self.need_load_ckpt = False
|
|
@@ -1236,6 +1337,9 @@ class Model:
|
|
|
1236
1337
|
self._loss_scale_manager.update_loss_scale(overflow)
|
|
1237
1338
|
|
|
1238
1339
|
list_callback.on_train_step_end(run_context)
|
|
1340
|
+
if cb_params.is_arf:
|
|
1341
|
+
cb_params.is_arf = False
|
|
1342
|
+
_set_recovery_context(is_arf=False)
|
|
1239
1343
|
# Embedding cache server only run one step.
|
|
1240
1344
|
if is_embedding_cache_server:
|
|
1241
1345
|
break
|
|
@@ -1333,10 +1437,9 @@ class Model:
|
|
|
1333
1437
|
... loss_scale_manager=loss_scale_manager)
|
|
1334
1438
|
>>> model.train(2, dataset)
|
|
1335
1439
|
"""
|
|
1440
|
+
_init_auto_parallel_context(self._network)
|
|
1336
1441
|
_check_tft()
|
|
1337
1442
|
device_target = context.get_context("device_target")
|
|
1338
|
-
# prepare dataset for obfuscated model
|
|
1339
|
-
train_dataset = self._prepare_obf_dataset(train_dataset)
|
|
1340
1443
|
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
|
1341
1444
|
logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
|
|
1342
1445
|
dataset_sink_mode = False
|
|
@@ -1392,6 +1495,8 @@ class Model:
|
|
|
1392
1495
|
if _enable_distributed_mindrt():
|
|
1393
1496
|
_reset_op_id_with_offset()
|
|
1394
1497
|
|
|
1498
|
+
_clear_auto_parallel_context(self._network)
|
|
1499
|
+
|
|
1395
1500
|
@staticmethod
|
|
1396
1501
|
def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
|
|
1397
1502
|
if get_debug_mode() and dataset_sink_mode:
|
|
@@ -1485,11 +1590,8 @@ class Model:
             >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
             >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
             >>> model.fit(2, train_dataset, valid_dataset)
-
-        Tutorial Examples:
-            - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
         """
+        _init_auto_parallel_context(self._network)
         device_target = context.get_context("device_target")
         if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
             logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
@@ -1541,8 +1643,9 @@ class Model:
                            valid_dataset=valid_dataset,
                            valid_frequency=valid_frequency,
                            valid_dataset_sink_mode=valid_dataset_sink_mode)
+        _clear_auto_parallel_context(self._network)
 
-    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
         """
         Build computational graphs and data graphs with the sink mode.
 
@@ -1561,6 +1664,7 @@ class Model:
                 will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
             sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
             epoch (int): Control the training epochs. Default: ``1`` .
+            sink_mode (bool): Determines whether to pass the data through dataset channel. Default: ``True`` .
 
         Examples:
             >>> from mindspore import nn
@@ -1581,20 +1685,22 @@ class Model:
             >>> model.build(dataset, epoch=2)
             >>> model.train(2, dataset)
         """
+        _init_auto_parallel_context(self._network)
         epoch = Validator.check_positive_int(epoch)
         if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
             self._train_network.check_names_and_refresh_name()
             self._train_network._is_check_and_refresh = True
         vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
         logger.info("Begin to init dataset in model.build() procedure.")
-        self._init(train_dataset, valid_dataset, sink_size, epoch)
+        self._init(train_dataset, valid_dataset, sink_size, epoch, sink_mode)
         vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                    "The model.build() which contains dataset warmup and network compile is success.")
         logger.info("The model.build() which contains dataset warmup and network compile is success.")
+        _clear_auto_parallel_context(self._network)
 
     def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
         """
-        Evaluation process in
+        Evaluation process in :func:`mindspore.train.Model.fit`.
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
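The `build()` hunks above add a `sink_mode` flag and forward it into `self._init(...)`, so graphs can be compiled and the dataset channel warmed up ahead of training with or without sinking. A short sketch reusing the `model` and `dataset` placeholders from the earlier sketch:

    # Pre-compile graphs and warm up the data pipeline before the actual run.
    # sink_mode=True keeps the default behaviour of feeding data through the dataset channel.
    model.build(train_dataset=dataset, sink_size=-1, epoch=2, sink_mode=True)
    model.train(2, dataset)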
@@ -1670,6 +1776,9 @@ class Model:
             cb_params.eval_results.update({"eval_loss": eval_loss})
         list_callback.on_eval_end(run_context)
 
+        dataset_helper.stop_send()
+        dataset_helper.release()
+
         return metrics
 
     def _eval_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
@@ -1757,12 +1866,8 @@ class Model:
             >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
             >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
             >>> acc = model.eval(dataset, dataset_sink_mode=False)
-
-        Tutorial Examples:
-            - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
         """
-
+        _init_auto_parallel_context(self._network)
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
 
         _device_number_check(self._parallel_mode, self._device_number)
@@ -1780,10 +1885,6 @@ class Model:
         cb_params.mode = "eval"
         cb_params.cur_step_num = 0
         cb_params.list_callback = self._transform_callbacks(callbacks)
-        if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
-                FlopsUtilizationCollector not in cb_params.list_callback:
-            cb_params.list_callback.insert(0, FlopsUtilizationCollector(
-                cb_params.batch_num, full_flops=False))
         cb_params.network = self._network
 
         self._clear_metrics()
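The hunk above removes the environment-variable hook that silently inserted a `FlopsUtilizationCollector` into `eval()`. If the flops statistics are still wanted, the collector can be passed explicitly as a callback; a sketch whose constructor arguments mirror the removed call (`data_size` positionally, `full_flops` as a keyword), with `model` and `dataset` as placeholders:

    from mindspore.train import FlopsUtilizationCollector

    steps = dataset.get_dataset_size()
    acc = model.eval(dataset,
                     callbacks=[FlopsUtilizationCollector(steps, full_flops=False)],
                     dataset_sink_mode=False)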
@@ -1811,6 +1912,7 @@ class Model:
         # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
         if _enable_distributed_mindrt():
             _reset_op_id_with_offset()
+        _clear_auto_parallel_context(self._network)
 
         return eval_result
 
@@ -1823,7 +1925,8 @@ class Model:
                 The predict data, can be a single tensor,
                 a list of tensor, or a tuple of tensor.
 
-            config (dict, optional)
+            config (dict, optional): The config parameter is enabled when the backend is ‘lite’.
+
                 The config includes two parts: config_path (configPath, str) and config_item (str, dict).
                 When the config_item is set, its priority is higher than the config_path. Set the ranking
                 table file for inference. The content of the configuration file is as follows:
@@ -1833,6 +1936,16 @@ class Model:
                 For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
                 config.ini file:
 
+                The config has 3 forms:
+                1. configPath defines the path of the configuration file, which is used to pass user-defined
+                   options during model building. Default value: ``"" ``.
+
+                   .. code-block::
+
+                       config = {"configPath" : "/home/user/config.ini"}
+
+                   Here is the content of the config.ini file:
+
                 .. code-block::
 
                     [ascend_context]
@@ -1841,20 +1954,15 @@ class Model:
                     [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
                     [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)
 
-
-
-                .. code-block::
-
-                    config = {"configPath" : "/home/user/config.ini"}
-
-                When only the config_dict is configured, it is done as follows:
+                2. Set the user-defined options in parameter dictionary, it is done as follows:
 
                 .. code-block::
 
                     config = {"ascend_context" : {"rank_table_file" : "path_b"},
                               "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
 
-
+                3. Both the `configPath` and the `parameter dictionary` are configured, The priority of the parameter
+                   dictionary is higher than that of the content in the configuration file. It is done as follows:
 
                 .. code-block::
 
@@ -1862,12 +1970,13 @@ class Model:
                               "ascend_context" : {"rank_table_file" : "path_b"},
                               "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
 
-                Note that
-
+                Note that in the "configPath" the parameter is set as "rank_table_file = [path_a]", but in dict is set
+                as "ascend_context" : {"rank_table_file" : "path_b"}, in this case, the path_b takes precedence.
 
         Returns:
             Tensor, array(s) of predictions.
         """
+
         def _get_lite_context(lite_context_input):
             # use default lite context parameters for now
             device_target = context.get_context("device_target").lower()
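The rewritten docstring above enumerates the three ways of passing `config` when `backend` is 'lite': a configuration-file path, a parameter dictionary, or both, with the dictionary taking precedence. A sketch of the combined form; the file path, operator names, and `input_data` are illustrative only:

    # Dictionary entries override what config.ini provides, so "path_b" wins over
    # any rank_table_file set inside the file.
    config = {
        "configPath": "/home/user/config.ini",
        "ascend_context": {"rank_table_file": "path_b"},
        "execution_plan": {"op_name3": "data_type:float16",
                           "op_name4": "data_type:float32"},
    }
    result = model.predict(input_data, backend="lite", config=config)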
@@ -1901,7 +2010,7 @@ class Model:
             if not self._mindspore_lite:
                 self._mindspore_lite = importlib.import_module('mindspore_lite')
 
-            use_past = False
+            use_past = False  # default execute full model inference
             model_group_id = None
             if self._predict_network.get_flags().__contains__("is_first_iteration"):
                 is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
@@ -2014,6 +2123,7 @@ class Model:
             >>> model = Model(LeNet5())
             >>> result = model.predict(input_data)
         """
+        _init_auto_parallel_context(self._network)
         if backend not in ['lite', None]:
             raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
         if backend == "lite" and self._lite_infer:
@@ -2029,6 +2139,7 @@ class Model:
             except BaseException as e:
                 self._lite_infer = False
                 logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
+        _clear_auto_parallel_context(self._network)
 
         def _check_input_data():
             """Input data check."""
@@ -2094,7 +2205,9 @@ class Model:
 
     def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
         """
-        Generate parameter layout for the train network
+        Generate parameter layout for the train network when using `AutoParallel(cell)`
+        to enable parallel mode.
+
         Only dataset sink mode is supported for now.
 
         .. warning::
@@ -2113,9 +2226,9 @@ class Model:
                 Configure pynative mode or CPU, the training process will be performed with
                 dataset not sink. Default: ``True`` .
             sink_size (int): Control the number of steps for each sinking.
+                If dataset_sink_mode is False, set sink_size as invalid.
                 If sink_size = -1, sink the complete dataset for each epoch.
                 If sink_size > 0, sink sink_size data for each epoch.
-                If dataset_sink_mode is False, set sink_size as invalid.
                 Default: ``-1`` .
 
         Returns:
@@ -2129,10 +2242,10 @@ class Model:
             >>> from mindspore import Tensor, nn
             >>> from mindspore.train import Model
             >>> from mindspore.communication import init
+            >>> from mindspore.parallel.auto_parallel import AutoParallel
             >>>
             >>> ms.set_context(mode=ms.GRAPH_MODE)
             >>> init()
-            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
             >>>
             >>> # Create the dataset taking MNIST as an example. Refer to
             >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
@@ -2140,13 +2253,15 @@ class Model:
             >>> # Define the network structure of LeNet5. Refer to
             >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
             >>> net = LeNet5()
+            >>> parallel_net = AutoParallel(net)
             >>> loss = nn.SoftmaxCrossEntropyWithLogits()
             >>> loss_scale_manager = ms.FixedLossScaleManager()
             >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
-            >>> model = Model(
+            >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
             ... loss_scale_manager=loss_scale_manager)
             >>> layout_dict = model.infer_train_layout(dataset)
         """
+        _init_auto_parallel_context(self._network)
         self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
 
         train_dataset.__no_send__ = True
@@ -2158,11 +2273,13 @@ class Model:
             train_network.compile(*inputs)
             break
         train_dataset.__model_hash__ = hash(self)
+        _clear_auto_parallel_context(self._network)
         return train_network.parameter_layout_dict
 
     def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
         """
-        Generate parameter layout for the predict network
+        Generate parameter layout for the predict network when using `AutoParallel(cell)`
+        to enable parallel mode.
 
         Data could be a single tensor or multiple tensors.
 
@@ -2185,21 +2302,47 @@ class Model:
             RuntimeError: If not in GRAPH_MODE.
 
         Examples:
-            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
-            >>> # mindspore.cn.
             >>> import numpy as np
-            >>> import mindspore as
+            >>> import mindspore.nn as nn
             >>> from mindspore import Tensor
             >>> from mindspore.train import Model
+            >>> from mindspore.ops import operations as P
+            >>> from mindspore import context
             >>> from mindspore.communication import init
+            >>> from mindspore.parallel.auto_parallel import AutoParallel
+            >>>
+            >>> class Net(nn.Cell):
+            >>>     def __init__(self):
+            >>>         super(Net, self).__init__()
+            >>>         self.fc1 = nn.Dense(128, 768, activation='relu')
+            >>>         self.fc2 = nn.Dense(128, 768, activation='relu')
+            >>>         self.fc3 = nn.Dense(128, 768, activation='relu')
+            >>>         self.fc4 = nn.Dense(768, 768, activation='relu')
+            >>>         self.relu4 = nn.ReLU()
+            >>>         self.relu5 = nn.ReLU()
+            >>>         self.transpose = P.Transpose()
+            >>>         self.matmul1 = P.MatMul()
+            >>>         self.matmul2 = P.MatMul()
+            >>>
+            >>>     def construct(self, x):
+            >>>         q = self.fc1(x)
+            >>>         k = self.fc2(x)
+            >>>         v = self.fc3(x)
+            >>>         k = self.transpose(k, (1, 0))
+            >>>         c = self.relu4(self.matmul1(q, k))
+            >>>         s = self.relu5(self.matmul2(c, v))
+            >>>         s = self.fc4(s)
+            >>>         return s
             >>>
             >>> ms.set_context(mode=ms.GRAPH_MODE)
             >>> init()
-            >>>
-            >>>
-            >>>
-            >>>
+            >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+            >>> net = Net()
+            >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
+            >>> model = Model(parallel_net)
+            >>> predict_map = model.infer_predict_layout(inputs)
         """
+        _init_auto_parallel_context(self._network)
         if context.get_context("mode") != context.GRAPH_MODE:
             raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
                                "only supports GRAPH MODE and Ascend target currently.")
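A common follow-up to the `infer_predict_layout` example above is to feed the returned layout into distributed checkpoint loading, so each rank only materializes the parameter slices it owns. A sketch under that assumption; the per-rank checkpoint file names are placeholders:

    import mindspore as ms

    predict_map = model.infer_predict_layout(inputs)
    ckpt_files = ["rank_0.ckpt", "rank_1.ckpt", "rank_2.ckpt", "rank_3.ckpt"]  # placeholder paths
    ms.load_distributed_checkpoint(network=net,
                                   checkpoint_filenames=ckpt_files,
                                   predict_strategy=predict_map)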
@@ -2219,6 +2362,7 @@ class Model:
             predict_net.phase = origin_phase
         else:
             predict_net.compile(*predict_data)
+        _clear_auto_parallel_context(self._network)
         return predict_net.parameter_layout_dict
 
     def _flush_from_cache(self, cb_params):
@@ -2258,16 +2402,5 @@ class Model:
         """
         return self._eval_network
 
-    def _prepare_obf_dataset(self, dataset):
-        if not hasattr(self._network, 'obf_ratios'):
-            return dataset
-        data_size = dataset.get_dataset_size()
-        obf_ratio_dataset = []
-        for _ in range(data_size):
-            obf_ratio_dataset.append(self._network.obf_ratios)
-        obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
-        dataset = ds.zip((dataset, obf_ratio_dataset))
-        return dataset
-
 
 __all__ = ["Model"]