mindspore 2.5.0-cp310-cp310-win_amd64.whl → 2.6.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
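Several entries above point to a new explicit parallel API (mindspore/parallel/auto_parallel.py is added, and mindspore/parallel/nn/ and mindspore/parallel/function/ are new packages); the mindspore/train/model.py diff below also switches its docstring examples from set_auto_parallel_context to this API. A minimal sketch of that migration, assuming AutoParallel simply wraps a cell as the changed infer_train_layout docstring suggests; constructor options beyond the cell itself are not taken from this diff:

    import mindspore as ms
    from mindspore import nn
    from mindspore.communication import init
    from mindspore.parallel.auto_parallel import AutoParallel  # module added in 2.6.0

    ms.set_context(mode=ms.GRAPH_MODE)
    init()  # requires a distributed launch (e.g. msrun); shown for completeness

    net = nn.Dense(16, 10)
    # 2.5.0 examples configured parallelism globally:
    #     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
    # 2.6.0 examples wrap the network cell instead:
    parallel_net = AutoParallel(net)
    model = ms.train.Model(parallel_net)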
mindspore/train/model.py
CHANGED
|
@@ -27,7 +27,6 @@ import time
|
|
|
27
27
|
import numpy as np
|
|
28
28
|
|
|
29
29
|
import mindspore
|
|
30
|
-
import mindspore.dataset as ds
|
|
31
30
|
from mindspore import log as logger
|
|
32
31
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
|
33
32
|
from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
|
|
@@ -36,7 +35,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
|
|
|
36
35
|
from mindspore._checkparam import check_input_data, check_output_data
|
|
37
36
|
from mindspore import _checkparam as Validator
|
|
38
37
|
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
|
|
39
|
-
|
|
38
|
+
TrainFaultTolerance
|
|
40
39
|
from mindspore.train.callback import __all__ as internal_cb_names
|
|
41
40
|
from mindspore.train.callback._cluster_monitor import ClusterMonitor
|
|
42
41
|
from mindspore import context
|
|
@@ -57,7 +56,11 @@ from mindspore.dataset.core.config import get_debug_mode
|
|
|
57
56
|
from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
|
|
58
57
|
from mindspore.train import amp
|
|
59
58
|
from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
|
|
59
|
+
from mindspore._c_expression import _get_optimzer_timestamps
|
|
60
|
+
from mindspore._c_expression import clean_tdt_channel
|
|
60
61
|
|
|
62
|
+
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
63
|
+
from .serialization import load_param_into_net
|
|
61
64
|
|
|
62
65
|
def _transfer_tensor_to_tuple(inputs):
|
|
63
66
|
"""
|
|
@@ -91,6 +94,7 @@ def _save_final_ckpt(func):
|
|
|
91
94
|
"""
|
|
92
95
|
Decorator function, which saves the current checkpoint when an exception occurs during training.
|
|
93
96
|
"""
|
|
97
|
+
|
|
94
98
|
@wraps(func)
|
|
95
99
|
def wrapper(self, *args, **kwargs):
|
|
96
100
|
obj = None
|
|
@@ -107,7 +111,7 @@ def _save_final_ckpt(func):
|
|
|
107
111
|
# pylint: disable=W0212
|
|
108
112
|
prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
|
|
109
113
|
cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
|
|
110
|
-
|
|
114
|
+
+ str(self._current_step_num) + "_breakpoint.ckpt"
|
|
111
115
|
cur_file = os.path.join(obj._directory, cur_ckpoint_file)
|
|
112
116
|
if "epoch_num" in obj._append_dict:
|
|
113
117
|
obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
|
|
@@ -118,88 +122,172 @@ def _save_final_ckpt(func):
|
|
|
118
122
|
raise e
|
|
119
123
|
else:
|
|
120
124
|
func(self, *args, **kwargs)
|
|
125
|
+
|
|
121
126
|
return wrapper
|
|
122
127
|
|
|
128
|
+
|
|
129
|
+
def _handle_exception_info(obj, uce_env, tft, e):
|
|
130
|
+
"""handle exception info"""
|
|
131
|
+
logger.info("uce wrapper caught RuntimeError")
|
|
132
|
+
if not uce_env:
|
|
133
|
+
logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
|
|
134
|
+
exc_info=True)
|
|
135
|
+
if tft:
|
|
136
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
137
|
+
raise e
|
|
138
|
+
e_str = str(e)
|
|
139
|
+
logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
140
|
+
if "UCEError" in e_str:
|
|
141
|
+
logger.info("uce wrapper report UCEError")
|
|
142
|
+
obj.is_uce_rank = True
|
|
143
|
+
# if error is HBM_MULTI_BIT_ECC_ERROR
|
|
144
|
+
if "error_code=507054" in e_str:
|
|
145
|
+
hbm_error_time, optimize_start, optimizer_end = _get_optimzer_timestamps()
|
|
146
|
+
can_repair = tft.tft_can_do_uce_repair(hbm_error_time, optimize_start, optimizer_end)
|
|
147
|
+
logger.info(f"UCEError of type HBM_MULTI_BIT_ECC_ERROR occurs, \
|
|
148
|
+
hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, \
|
|
149
|
+
optimizer_end={optimizer_end}, can_repair={can_repair}")
|
|
150
|
+
if not can_repair:
|
|
151
|
+
logger.error(f"Caught UCEError of type HBM_MULTI_BIT_ECC_ERROR but can not repair, "
|
|
152
|
+
f"hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, "
|
|
153
|
+
f"optimizer_end={optimizer_end}", exc_info=True)
|
|
154
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
155
|
+
raise e
|
|
156
|
+
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
157
|
+
elif "ForceStopError" in e_str:
|
|
158
|
+
logger.warning("uce wrapper caught RuntimeError ForceStopError")
|
|
159
|
+
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
160
|
+
tft.tft_report_error(force_stop_err)
|
|
161
|
+
elif "ARF FINISH" in e_str:
|
|
162
|
+
logger.warning(f"ARF FINISH")
|
|
163
|
+
_set_recovery_context(is_arf=True)
|
|
164
|
+
tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
|
|
165
|
+
else:
|
|
166
|
+
logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
|
|
167
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
168
|
+
raise e
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _handle_training_result_error(model, tft_obj):
|
|
172
|
+
"""
|
|
173
|
+
Handle training result error for resuming training.
|
|
174
|
+
"""
|
|
175
|
+
ckpt_load_fn = tft_obj.ckpt_load_func
|
|
176
|
+
train_network = tft_obj.cb_params.train_network
|
|
177
|
+
logger.warning("Process training result error start.")
|
|
178
|
+
# 1. Clear tdt channel
|
|
179
|
+
logger.warning("Clean tdt channel.")
|
|
180
|
+
clean_tdt_channel()
|
|
181
|
+
|
|
182
|
+
# 2. Load checkpoint
|
|
183
|
+
logger.warning("Load checkpoint.")
|
|
184
|
+
new_param_dict, remove_redundancy = ckpt_load_fn()
|
|
185
|
+
param_not_load, ckpt_not_load = load_param_into_net(train_network, new_param_dict, True, remove_redundancy)
|
|
186
|
+
logger.warning(f"param_not_load: {param_not_load}")
|
|
187
|
+
logger.warning(f"ckpt_not_load: {ckpt_not_load}")
|
|
188
|
+
resume_epoch = new_param_dict.get('epoch_num')
|
|
189
|
+
resume_step = new_param_dict.get('step_num')
|
|
190
|
+
model._initial_step = int(resume_step.asnumpy())
|
|
191
|
+
logger.warning("Process training result error end.")
|
|
192
|
+
return (resume_epoch, resume_step)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _calc_cb_initial_step(org_epoch, org_step, *args, **kwargs):
|
|
196
|
+
"""calculate initial step for callback"""
|
|
197
|
+
train_dataset = args[1]
|
|
198
|
+
dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
|
|
199
|
+
sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
|
|
200
|
+
|
|
201
|
+
cb_initial_step = 0
|
|
202
|
+
if dataset_sink_mode:
|
|
203
|
+
train_dataset.set_init_step(org_epoch)
|
|
204
|
+
dataset_size = train_dataset.get_dataset_size()
|
|
205
|
+
if sink_size != -1:
|
|
206
|
+
cb_initial_step = org_epoch * sink_size + org_step
|
|
207
|
+
else:
|
|
208
|
+
cb_initial_step = org_epoch * dataset_size + org_step
|
|
209
|
+
else:
|
|
210
|
+
train_dataset.set_init_step(org_step)
|
|
211
|
+
cb_initial_step = org_step
|
|
212
|
+
if hasattr(train_dataset, '_dataset_helper'):
|
|
213
|
+
dataset_helper = train_dataset._dataset_helper
|
|
214
|
+
_reset_training_dataset(cb_initial_step, dataset_helper.iter.dataset.get_dataset_size())
|
|
215
|
+
return cb_initial_step
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _update_ckpt_callback_info(resume_train_step, **kwargs):
|
|
219
|
+
"""
|
|
220
|
+
Update checkpoint callback internal state
|
|
221
|
+
"""
|
|
222
|
+
ckpt_obj = None
|
|
223
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
|
|
224
|
+
ckpt_obj = kwargs.get('callbacks')
|
|
225
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
|
226
|
+
for item in kwargs.get('callbacks'):
|
|
227
|
+
if isinstance(item, ModelCheckpoint):
|
|
228
|
+
ckpt_obj = item
|
|
229
|
+
if ckpt_obj is not None:
|
|
230
|
+
ckpt_obj._last_triggered_step = 0
|
|
231
|
+
ckpt_obj._append_step_num = resume_train_step
|
|
232
|
+
|
|
233
|
+
|
|
123
234
|
def _handle_tft(func):
|
|
124
235
|
"""
|
|
125
236
|
Decorator function, which starts uce handle process when an exception occurs during training.
|
|
126
237
|
"""
|
|
238
|
+
|
|
127
239
|
@wraps(func)
|
|
128
240
|
def wrapper(self, *args, **kwargs):
|
|
129
241
|
obj = None
|
|
130
|
-
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'),
|
|
242
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
|
|
131
243
|
obj = kwargs.get('callbacks')
|
|
132
244
|
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
|
133
245
|
for item in kwargs.get('callbacks'):
|
|
134
|
-
if isinstance(item,
|
|
246
|
+
if isinstance(item, TrainFaultTolerance):
|
|
135
247
|
obj = item
|
|
136
248
|
if obj:
|
|
137
249
|
tft = obj.tft
|
|
138
250
|
tft_env = os.getenv("MS_ENABLE_TFT", "")
|
|
139
|
-
uce_env = "UCE:1" in tft_env
|
|
251
|
+
uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
|
|
252
|
+
tre_env = "TRE:1" in tft_env
|
|
140
253
|
while True:
|
|
141
254
|
try:
|
|
142
255
|
return func(self, *args, **kwargs)
|
|
143
256
|
except RuntimeError as e:
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
raise e
|
|
150
|
-
e_str = str(e)
|
|
151
|
-
logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
152
|
-
if "UCEError" in e_str:
|
|
153
|
-
logger.info("uce wrapper report UCEError")
|
|
154
|
-
obj.is_uce_rank = True
|
|
155
|
-
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
156
|
-
elif "ForceStopError" in e_str:
|
|
157
|
-
logger.info("uce wrapper caught RuntimeError ForceStopError")
|
|
158
|
-
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
159
|
-
tft.tft_report_error(force_stop_err)
|
|
257
|
+
if tre_env and 'TREError' in str(e):
|
|
258
|
+
_, resume_step = _handle_training_result_error(self, obj)
|
|
259
|
+
repair_step = int(resume_step.asnumpy())
|
|
260
|
+
_update_ckpt_callback_info(repair_step, **kwargs)
|
|
261
|
+
logger.warning(f'Resume training after TREError from step {repair_step}.')
|
|
160
262
|
else:
|
|
161
|
-
|
|
162
|
-
tft.
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
initial_epoch = int(repair_step/self.batch_num)
|
|
263
|
+
_handle_exception_info(obj, uce_env, tft, e)
|
|
264
|
+
ret = tft.tft_wait_next_action()
|
|
265
|
+
if ret == tft.Action.EXIT.value:
|
|
266
|
+
raise e
|
|
267
|
+
repair_step = tft.tft_get_repair_step()
|
|
268
|
+
logger.warning(
|
|
269
|
+
"uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
|
|
270
|
+
self.batch_num))
|
|
271
|
+
initial_epoch = int(repair_step / self.batch_num)
|
|
171
272
|
initial_step = repair_step % self.batch_num
|
|
172
273
|
kwargs["initial_epoch"] = initial_epoch
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
|
|
177
|
-
|
|
178
|
-
cb_initial_step = 0
|
|
179
|
-
if dataset_sink_mode:
|
|
180
|
-
train_dataset.set_init_step(initial_epoch)
|
|
181
|
-
dataset_size = train_dataset.get_dataset_size()
|
|
182
|
-
if sink_size != -1:
|
|
183
|
-
cb_initial_step = initial_epoch * sink_size + initial_step
|
|
184
|
-
else:
|
|
185
|
-
cb_initial_step = initial_epoch * dataset_size + initial_step
|
|
186
|
-
else:
|
|
187
|
-
train_dataset.set_init_step(initial_step)
|
|
188
|
-
cb_initial_step = initial_step
|
|
189
|
-
|
|
190
|
-
kwargs["initial_step"] = cb_initial_step
|
|
274
|
+
cb_initial_step = _calc_cb_initial_step(initial_epoch, initial_step, *args, **kwargs)
|
|
275
|
+
if not self.enable_tre:
|
|
276
|
+
kwargs["initial_step"] = cb_initial_step
|
|
191
277
|
# reset all accu grads to zero
|
|
192
278
|
obj._reset_acc_grads()
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
279
|
+
logger.warning(
|
|
280
|
+
"uce wrapper repair complete initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch,
|
|
281
|
+
cb_initial_step))
|
|
196
282
|
continue
|
|
197
283
|
except BaseException as e:
|
|
198
|
-
|
|
199
|
-
|
|
284
|
+
if tft:
|
|
285
|
+
logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
|
|
286
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
200
287
|
raise e
|
|
201
288
|
else:
|
|
202
289
|
return func(self, *args, **kwargs)
|
|
290
|
+
|
|
203
291
|
return wrapper
|
|
204
292
|
|
|
205
293
|
|
|
@@ -216,7 +304,7 @@ def _check_tft():
|
|
|
216
304
|
if ms_mode != mindspore.GRAPH_MODE:
|
|
217
305
|
raise ValueError("TFT is only supported in GRAPH_MODE")
|
|
218
306
|
jit_level = context.get_context("jit_level")
|
|
219
|
-
if jit_level == "O2" and "UCE:1" in tft_env:
|
|
307
|
+
if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
|
|
220
308
|
raise ValueError("TFT is not supported when using jit_level == O2")
|
|
221
309
|
|
|
222
310
|
|
|
@@ -406,12 +494,13 @@ class Model:
|
|
|
406
494
|
the accuracy is reduced by less than 3%.
|
|
407
495
|
|
|
408
496
|
If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
|
|
409
|
-
In order for this function to work, you need to set the optimizer
|
|
410
|
-
at the
|
|
497
|
+
In order for this function to work, you need to set the parameter `optimizer`, along with
|
|
498
|
+
at least one of the parameter `eval_network` or performance `metrics`.
|
|
411
499
|
|
|
412
500
|
Notice: The current optimization enabled by default only applies to some networks, and not all networks
|
|
413
501
|
can obtain the same benefits. It is recommended to enable this function on
|
|
414
|
-
the Graph mode + Ascend platform, and for better acceleration,
|
|
502
|
+
the Graph mode + Ascend platform, and for better acceleration,
|
|
503
|
+
refer to :class:`mindspore.boost.AutoBoost` to configure
|
|
415
504
|
boost_config_dict.
|
|
416
505
|
|
|
417
506
|
Examples:
|
|
@@ -436,6 +525,7 @@ class Model:
|
|
|
436
525
|
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
|
|
437
526
|
amp_level="O0", boost_level="O0", **kwargs):
|
|
438
527
|
self._network = network
|
|
528
|
+
_init_auto_parallel_context(self._network)
|
|
439
529
|
self._loss_fn = loss_fn
|
|
440
530
|
self._optimizer = optimizer
|
|
441
531
|
self._loss_scale_manager = None
|
|
@@ -470,6 +560,9 @@ class Model:
|
|
|
470
560
|
self._lite_infer = True # if backend lite infer fails, set False
|
|
471
561
|
self._mindspore_lite_model_group_id = id(self) & 0xFFFF
|
|
472
562
|
self.batch_num = -1
|
|
563
|
+
self.enable_tre = "TRE:1" in os.getenv("MS_ENABLE_TFT", "")
|
|
564
|
+
self._initial_step = None
|
|
565
|
+
_clear_auto_parallel_context(self._network)
|
|
473
566
|
|
|
474
567
|
def _check_for_graph_cell(self, kwargs):
|
|
475
568
|
"""Check for graph cell"""
|
|
@@ -668,7 +761,7 @@ class Model:
|
|
|
668
761
|
logger.info("Begin to connect network with dataset.")
|
|
669
762
|
network = connect_network_with_dataset(network, dataset_helper)
|
|
670
763
|
|
|
671
|
-
if _get_recovery_context("enable_recovery") and is_train:
|
|
764
|
+
if (_get_recovery_context("enable_recovery") or self.enable_tre) and is_train:
|
|
672
765
|
_set_training_dataset(dataset_helper)
|
|
673
766
|
|
|
674
767
|
network.set_train(is_train)
|
|
@@ -765,7 +858,7 @@ class Model:
|
|
|
765
858
|
break
|
|
766
859
|
logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
|
|
767
860
|
|
|
768
|
-
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
|
|
861
|
+
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
|
|
769
862
|
"""
|
|
770
863
|
Initialize compute graphs and data graphs with the sink mode.
|
|
771
864
|
|
|
@@ -794,7 +887,6 @@ class Model:
|
|
|
794
887
|
if not isinstance(train_dataset, mindspore.dataset.Dataset):
|
|
795
888
|
raise TypeError("The type of 'train_dataset' must be `Dataset`, "
|
|
796
889
|
"but got {}.".format(type(train_dataset)))
|
|
797
|
-
|
|
798
890
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
799
891
|
"Begin to check parameter broadcast in model.build().")
|
|
800
892
|
logger.info("Begin to check parameter broadcast in model.build() procedure.")
|
|
@@ -807,23 +899,24 @@ class Model:
|
|
|
807
899
|
train_dataset.__no_send__ = True
|
|
808
900
|
train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
|
|
809
901
|
dataset=train_dataset,
|
|
810
|
-
dataset_sink_mode=
|
|
902
|
+
dataset_sink_mode=sink_mode,
|
|
811
903
|
sink_size=sink_size)
|
|
812
904
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
|
|
813
|
-
|
|
814
|
-
|
|
905
|
+
if sink_mode:
|
|
906
|
+
logger.info("Begin to warmup dataset in model.build() procedure.")
|
|
907
|
+
self._warmup_dataset(epoch, train_dataset, sink_size)
|
|
815
908
|
|
|
816
|
-
|
|
817
|
-
|
|
909
|
+
# Since dataset pipeline has been triggered, delete flag
|
|
910
|
+
delattr(train_dataset, "__no_send__")
|
|
818
911
|
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
912
|
+
# Waiting for the dataset warmup ready
|
|
913
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
914
|
+
"Begin waiting for dataset warmup in model.build().")
|
|
915
|
+
logger.info("Begin waiting for dataset warmup in model.build() procedure.")
|
|
916
|
+
self._waiting_for_dataset_warmup_ready(train_dataset)
|
|
917
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
918
|
+
"The dataset warmup was successful in model.build().")
|
|
919
|
+
logger.info("The dataset warmup was successful in model.build() procedure.")
|
|
827
920
|
|
|
828
921
|
if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
|
|
829
922
|
train_network.add_flags_recursive(is_first_iteration=True)
|
|
@@ -833,6 +926,7 @@ class Model:
|
|
|
833
926
|
logger.info("Begin to compile train network in model.build() procedure.")
|
|
834
927
|
train_network.compile(*inputs)
|
|
835
928
|
self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
|
|
929
|
+
train_dataset.reset()
|
|
836
930
|
break
|
|
837
931
|
|
|
838
932
|
if valid_dataset:
|
|
@@ -846,7 +940,7 @@ class Model:
|
|
|
846
940
|
valid_dataset.__no_send__ = True
|
|
847
941
|
valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
|
|
848
942
|
dataset=valid_dataset,
|
|
849
|
-
dataset_sink_mode=
|
|
943
|
+
dataset_sink_mode=sink_mode)
|
|
850
944
|
if context.get_auto_parallel_context("pipeline_stages") > 1:
|
|
851
945
|
eval_network.add_flags_recursive(is_first_iteration=False)
|
|
852
946
|
for inputs in valid_dataset_helper:
|
|
@@ -854,6 +948,7 @@ class Model:
|
|
|
854
948
|
"Begin to compile eval network in model.build().")
|
|
855
949
|
logger.info("Begin to compile eval network in model.build() procedure.")
|
|
856
950
|
eval_network.compile(*inputs)
|
|
951
|
+
valid_dataset.reset()
|
|
857
952
|
break
|
|
858
953
|
|
|
859
954
|
@staticmethod
|
|
@@ -922,6 +1017,8 @@ class Model:
|
|
|
922
1017
|
cb_params.last_save_ckpt_step = None
|
|
923
1018
|
cb_params.latest_ckpt_file = None
|
|
924
1019
|
cb_params.loss_scale_mananger = self._loss_scale_manager
|
|
1020
|
+
cb_params.is_arf = _get_recovery_context("is_arf")
|
|
1021
|
+
cb_params.initial_step = self._initial_step
|
|
925
1022
|
|
|
926
1023
|
# build callback list
|
|
927
1024
|
with _CallbackManager(callbacks) as list_callback:
|
|
@@ -1026,6 +1123,9 @@ class Model:
|
|
|
1026
1123
|
need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
|
|
1027
1124
|
if need_exec_callback_step_end:
|
|
1028
1125
|
list_callback.on_train_step_end(run_context)
|
|
1126
|
+
if cb_params.is_arf:
|
|
1127
|
+
cb_params.is_arf = False
|
|
1128
|
+
_set_recovery_context(is_arf=False)
|
|
1029
1129
|
|
|
1030
1130
|
# Embedding cache server only run one step.
|
|
1031
1131
|
if is_embedding_cache_server:
|
|
@@ -1056,7 +1156,7 @@ class Model:
|
|
|
1056
1156
|
if should_stop:
|
|
1057
1157
|
break
|
|
1058
1158
|
|
|
1059
|
-
need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset")\
|
|
1159
|
+
need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
|
|
1060
1160
|
and not _get_recovery_context("latest_ckpt_file")
|
|
1061
1161
|
self.epoch_iter += 1
|
|
1062
1162
|
if need_reset_to_beginning:
|
|
@@ -1100,7 +1200,7 @@ class Model:
|
|
|
1100
1200
|
Check whether enable recovery and execution mode consistency.
|
|
1101
1201
|
"""
|
|
1102
1202
|
|
|
1103
|
-
enable_recovery = _get_recovery_context("enable_recovery")
|
|
1203
|
+
enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
|
|
1104
1204
|
if not enable_recovery:
|
|
1105
1205
|
self.enable_recovery = False
|
|
1106
1206
|
else:
|
|
@@ -1117,6 +1217,8 @@ class Model:
|
|
|
1117
1217
|
dataset_size (int): The number of batches in a dataset.
|
|
1118
1218
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
|
1119
1219
|
"""
|
|
1220
|
+
if context.get_context("device_target") != "GPU":
|
|
1221
|
+
return
|
|
1120
1222
|
if not self.enable_recovery:
|
|
1121
1223
|
self.need_load_ckpt = False
|
|
1122
1224
|
|
|
@@ -1145,7 +1247,7 @@ class Model:
|
|
|
1145
1247
|
load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
|
|
1146
1248
|
except BaseException as e:
|
|
1147
1249
|
os.remove(cb_params.latest_ckpt_file)
|
|
1148
|
-
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
|
|
1250
|
+
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
|
|
1149
1251
|
+ cb_params.latest_ckpt_file) from e
|
|
1150
1252
|
_reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
|
|
1151
1253
|
self.need_load_ckpt = False
|
|
@@ -1235,6 +1337,9 @@ class Model:
|
|
|
1235
1337
|
self._loss_scale_manager.update_loss_scale(overflow)
|
|
1236
1338
|
|
|
1237
1339
|
list_callback.on_train_step_end(run_context)
|
|
1340
|
+
if cb_params.is_arf:
|
|
1341
|
+
cb_params.is_arf = False
|
|
1342
|
+
_set_recovery_context(is_arf=False)
|
|
1238
1343
|
# Embedding cache server only run one step.
|
|
1239
1344
|
if is_embedding_cache_server:
|
|
1240
1345
|
break
|
|
@@ -1332,10 +1437,9 @@ class Model:
|
|
|
1332
1437
|
... loss_scale_manager=loss_scale_manager)
|
|
1333
1438
|
>>> model.train(2, dataset)
|
|
1334
1439
|
"""
|
|
1440
|
+
_init_auto_parallel_context(self._network)
|
|
1335
1441
|
_check_tft()
|
|
1336
1442
|
device_target = context.get_context("device_target")
|
|
1337
|
-
# prepare dataset for obfuscated model
|
|
1338
|
-
train_dataset = self._prepare_obf_dataset(train_dataset)
|
|
1339
1443
|
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
|
1340
1444
|
logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
|
|
1341
1445
|
dataset_sink_mode = False
|
|
@@ -1391,6 +1495,8 @@ class Model:
|
|
|
1391
1495
|
if _enable_distributed_mindrt():
|
|
1392
1496
|
_reset_op_id_with_offset()
|
|
1393
1497
|
|
|
1498
|
+
_clear_auto_parallel_context(self._network)
|
|
1499
|
+
|
|
1394
1500
|
@staticmethod
|
|
1395
1501
|
def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
|
|
1396
1502
|
if get_debug_mode() and dataset_sink_mode:
|
|
@@ -1484,11 +1590,8 @@ class Model:
|
|
|
1484
1590
|
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
1485
1591
|
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
|
|
1486
1592
|
>>> model.fit(2, train_dataset, valid_dataset)
|
|
1487
|
-
|
|
1488
|
-
Tutorial Examples:
|
|
1489
|
-
- `Advanced Encapsulation: Model - Train and Save Model
|
|
1490
|
-
<https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
|
|
1491
1593
|
"""
|
|
1594
|
+
_init_auto_parallel_context(self._network)
|
|
1492
1595
|
device_target = context.get_context("device_target")
|
|
1493
1596
|
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
|
1494
1597
|
logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
|
|
@@ -1540,8 +1643,9 @@ class Model:
|
|
|
1540
1643
|
valid_dataset=valid_dataset,
|
|
1541
1644
|
valid_frequency=valid_frequency,
|
|
1542
1645
|
valid_dataset_sink_mode=valid_dataset_sink_mode)
|
|
1646
|
+
_clear_auto_parallel_context(self._network)
|
|
1543
1647
|
|
|
1544
|
-
def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
|
|
1648
|
+
def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
|
|
1545
1649
|
"""
|
|
1546
1650
|
Build computational graphs and data graphs with the sink mode.
|
|
1547
1651
|
|
|
@@ -1560,6 +1664,7 @@ class Model:
|
|
|
1560
1664
|
will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
|
|
1561
1665
|
sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
|
|
1562
1666
|
epoch (int): Control the training epochs. Default: ``1`` .
|
|
1667
|
+
sink_mode (bool): Determines whether to pass the data through dataset channel. Default: ``True`` .
|
|
1563
1668
|
|
|
1564
1669
|
Examples:
|
|
1565
1670
|
>>> from mindspore import nn
|
|
@@ -1580,16 +1685,18 @@ class Model:
|
|
|
1580
1685
|
>>> model.build(dataset, epoch=2)
|
|
1581
1686
|
>>> model.train(2, dataset)
|
|
1582
1687
|
"""
|
|
1688
|
+
_init_auto_parallel_context(self._network)
|
|
1583
1689
|
epoch = Validator.check_positive_int(epoch)
|
|
1584
1690
|
if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
|
|
1585
1691
|
self._train_network.check_names_and_refresh_name()
|
|
1586
1692
|
self._train_network._is_check_and_refresh = True
|
|
1587
1693
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
|
|
1588
1694
|
logger.info("Begin to init dataset in model.build() procedure.")
|
|
1589
|
-
self._init(train_dataset, valid_dataset, sink_size, epoch)
|
|
1695
|
+
self._init(train_dataset, valid_dataset, sink_size, epoch, sink_mode)
|
|
1590
1696
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
1591
1697
|
"The model.build() which contains dataset warmup and network compile is success.")
|
|
1592
1698
|
logger.info("The model.build() which contains dataset warmup and network compile is success.")
|
|
1699
|
+
_clear_auto_parallel_context(self._network)
|
|
1593
1700
|
|
|
1594
1701
|
def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
|
|
1595
1702
|
"""
|
|
@@ -1759,12 +1866,8 @@ class Model:
|
|
|
1759
1866
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
1760
1867
|
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
|
1761
1868
|
>>> acc = model.eval(dataset, dataset_sink_mode=False)
|
|
1762
|
-
|
|
1763
|
-
Tutorial Examples:
|
|
1764
|
-
- `Advanced Encapsulation: Model - Train and Save Model
|
|
1765
|
-
<https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
|
|
1766
1869
|
"""
|
|
1767
|
-
|
|
1870
|
+
_init_auto_parallel_context(self._network)
|
|
1768
1871
|
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
|
1769
1872
|
|
|
1770
1873
|
_device_number_check(self._parallel_mode, self._device_number)
|
|
@@ -1809,6 +1912,7 @@ class Model:
|
|
|
1809
1912
|
# This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
|
|
1810
1913
|
if _enable_distributed_mindrt():
|
|
1811
1914
|
_reset_op_id_with_offset()
|
|
1915
|
+
_clear_auto_parallel_context(self._network)
|
|
1812
1916
|
|
|
1813
1917
|
return eval_result
|
|
1814
1918
|
|
|
@@ -1821,7 +1925,8 @@ class Model:
1821  1925     The predict data, can be a single tensor,
1822  1926     a list of tensor, or a tuple of tensor.
1823  1927
1824        -  config (dict, optional)
      1928  +  config (dict, optional): The config parameter is enabled when the backend is 'lite'.
      1929  +
1825  1930     The config includes two parts: config_path (configPath, str) and config_item (str, dict).
1826  1931     When the config_item is set, its priority is higher than the config_path. Set the ranking
1827  1932     table file for inference. The content of the configuration file is as follows:

@@ -1831,6 +1936,16 @@ class Model:
1831  1936     For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
1832  1937     config.ini file:
1833  1938
      1939  +  The config has 3 forms:
      1940  +  1. configPath defines the path of the configuration file, which is used to pass user-defined
      1941  +     options during model building. Default value: ``"" ``.
      1942  +
      1943  +  .. code-block::
      1944  +
      1945  +      config = {"configPath" : "/home/user/config.ini"}
      1946  +
      1947  +  Here is the content of the config.ini file:
      1948  +
1834  1949     .. code-block::
1835  1950
1836  1951         [ascend_context]

@@ -1839,20 +1954,15 @@ class Model:
1839  1954         [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
1840  1955         [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)
1841  1956
1842        -
1843        -
1844        -  .. code-block::
1845        -
1846        -      config = {"configPath" : "/home/user/config.ini"}
1847        -
1848        -  When only the config_dict is configured, it is done as follows:
      1957  +  2. Set the user-defined options in parameter dictionary, it is done as follows:
1849  1958
1850  1959     .. code-block::
1851  1960
1852  1961         config = {"ascend_context" : {"rank_table_file" : "path_b"},
1853  1962                   "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
1854  1963
1855        -
      1964  +  3. Both the `configPath` and the `parameter dictionary` are configured, The priority of the parameter
      1965  +     dictionary is higher than that of the content in the configuration file. It is done as follows:
1856  1966
1857  1967     .. code-block::
1858  1968

@@ -1860,12 +1970,13 @@ class Model:
1860  1970         "ascend_context" : {"rank_table_file" : "path_b"},
1861  1971         "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
1862  1972
1863        -  Note that
1864        -
      1973  +  Note that in the "configPath" the parameter is set as "rank_table_file = [path_a]", but in dict is set
      1974  +  as "ascend_context" : {"rank_table_file" : "path_b"}, in this case, the path_b takes precedence.
1865  1975
1866  1976     Returns:
1867  1977         Tensor, array(s) of predictions.
1868  1978     """
      1979  +
1869  1980     def _get_lite_context(lite_context_input):
1870  1981         # use default lite context parameters for now
1871  1982         device_target = context.get_context("device_target").lower()
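
Putting the three documented forms together, a hedged call sketch for the lite backend (the paths, operator names, and input_data are placeholders; config only takes effect when backend='lite'):

    config = {
        "configPath": "/home/user/config.ini",              # form 1: options read from a file
        "ascend_context": {"rank_table_file": "path_b"},     # forms 2/3: dict entries override the file
        "execution_plan": {"op_name1": "data_type:float16",
                           "op_name2": "data_type:float32"},
    }
    result = model.predict(input_data, config=config, backend="lite")
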
@@ -1899,7 +2010,7 @@ class Model:
1899  2010     if not self._mindspore_lite:
1900  2011         self._mindspore_lite = importlib.import_module('mindspore_lite')
1901  2012
1902        -  use_past = False
      2013  +  use_past = False  # default execute full model inference
1903  2014     model_group_id = None
1904  2015     if self._predict_network.get_flags().__contains__("is_first_iteration"):
1905  2016         is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
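
The comment added to use_past spells out the default: full-model inference unless the predict network carries incremental-inference flags. A hedged sketch of how such a flag becomes visible to get_flags() (MyDecoderNet is a hypothetical network; add_flags_recursive and get_flags are standard nn.Cell methods):

    net = MyDecoderNet()                              # hypothetical network with an incremental decoding path
    net.add_flags_recursive(is_first_iteration=True)  # custom flag stored on the cell
    assert net.get_flags()["is_first_iteration"]      # Model.predict reads this flag when choosing the inference path
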
@@ -2012,6 +2123,7 @@ class Model:
2012  2123     >>> model = Model(LeNet5())
2013  2124     >>> result = model.predict(input_data)
2014  2125     """
      2126  +  _init_auto_parallel_context(self._network)
2015  2127     if backend not in ['lite', None]:
2016  2128         raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
2017  2129     if backend == "lite" and self._lite_infer:

@@ -2027,6 +2139,7 @@ class Model:
2027  2139     except BaseException as e:
2028  2140         self._lite_infer = False
2029  2141         logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
      2142  +  _clear_auto_parallel_context(self._network)
2030  2143
2031  2144     def _check_input_data():
2032  2145         """Input data check."""
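
In short, backend must be either 'lite' or None, and if lite inference raises, Model logs a warning and falls back to the original inference path. A hedged one-liner (input_data is a placeholder):

    result = model.predict(input_data, backend="lite")   # falls back to the default backend if mindspore_lite fails
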
@@ -2092,7 +2205,9 @@ class Model:
2092  2205
2093  2206     def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
2094  2207         """
2095        -      Generate parameter layout for the train network
      2208  +      Generate parameter layout for the train network when using `AutoParallel(cell)`
      2209  +      to enable parallel mode.
      2210  +
2096  2211         Only dataset sink mode is supported for now.
2097  2212
2098  2213         .. warning::

@@ -2111,9 +2226,9 @@ class Model:
2111  2226             Configure pynative mode or CPU, the training process will be performed with
2112  2227             dataset not sink. Default: ``True`` .
2113  2228         sink_size (int): Control the number of steps for each sinking.
      2229  +        If dataset_sink_mode is False, set sink_size as invalid.
2114  2230             If sink_size = -1, sink the complete dataset for each epoch.
2115  2231             If sink_size > 0, sink sink_size data for each epoch.
2116        -          If dataset_sink_mode is False, set sink_size as invalid.
2117  2232             Default: ``-1`` .
2118  2233
2119  2234         Returns:

@@ -2127,10 +2242,10 @@ class Model:
2127  2242     >>> from mindspore import Tensor, nn
2128  2243     >>> from mindspore.train import Model
2129  2244     >>> from mindspore.communication import init
      2245  +  >>> from mindspore.parallel.auto_parallel import AutoParallel
2130  2246     >>>
2131  2247     >>> ms.set_context(mode=ms.GRAPH_MODE)
2132  2248     >>> init()
2133        -  >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
2134  2249     >>>
2135  2250     >>> # Create the dataset taking MNIST as an example. Refer to
2136  2251     >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py

@@ -2138,13 +2253,15 @@ class Model:
2138  2253     >>> # Define the network structure of LeNet5. Refer to
2139  2254     >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
2140  2255     >>> net = LeNet5()
      2256  +  >>> parallel_net = AutoParallel(net)
2141  2257     >>> loss = nn.SoftmaxCrossEntropyWithLogits()
2142  2258     >>> loss_scale_manager = ms.FixedLossScaleManager()
2143  2259     >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
2144        -  >>> model = Model(
      2260  +  >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
2145  2261     ...               loss_scale_manager=loss_scale_manager)
2146  2262     >>> layout_dict = model.infer_train_layout(dataset)
2147  2263     """
      2264  +  _init_auto_parallel_context(self._network)
2148  2265     self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
2149  2266
2150  2267     train_dataset.__no_send__ = True
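
A hedged call sketch matching the reordered sink_size description (train_dataset is a placeholder):

    # sink 100 steps per epoch; with dataset_sink_mode=False the sink_size value is ignored
    layout_dict = model.infer_train_layout(train_dataset, dataset_sink_mode=True, sink_size=100)
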
@@ -2156,11 +2273,13 @@ class Model:
2156  2273         train_network.compile(*inputs)
2157  2274         break
2158  2275     train_dataset.__model_hash__ = hash(self)
      2276  +  _clear_auto_parallel_context(self._network)
2159  2277     return train_network.parameter_layout_dict
2160  2278
2161  2279     def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
2162  2280         """
2163        -      Generate parameter layout for the predict network
      2281  +      Generate parameter layout for the predict network when using `AutoParallel(cell)`
      2282  +      to enable parallel mode.
2164  2283
2165  2284         Data could be a single tensor or multiple tensors.
2166  2285

@@ -2183,21 +2302,47 @@ class Model:
2183  2302     RuntimeError: If not in GRAPH_MODE.
2184  2303
2185  2304     Examples:
2186        -  >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
2187        -  >>> # mindspore.cn.
2188  2305     >>> import numpy as np
2189        -  >>> import mindspore as ms
      2306  +  >>> import mindspore.nn as nn
2190  2307     >>> from mindspore import Tensor
2191  2308     >>> from mindspore.train import Model
      2309  +  >>> from mindspore.ops import operations as P
      2310  +  >>> from mindspore import context
2192  2311     >>> from mindspore.communication import init
      2312  +  >>> from mindspore.parallel.auto_parallel import AutoParallel
      2313  +  >>>
      2314  +  >>> class Net(nn.Cell):
      2315  +  >>>     def __init__(self):
      2316  +  >>>         super(Net, self).__init__()
      2317  +  >>>         self.fc1 = nn.Dense(128, 768, activation='relu')
      2318  +  >>>         self.fc2 = nn.Dense(128, 768, activation='relu')
      2319  +  >>>         self.fc3 = nn.Dense(128, 768, activation='relu')
      2320  +  >>>         self.fc4 = nn.Dense(768, 768, activation='relu')
      2321  +  >>>         self.relu4 = nn.ReLU()
      2322  +  >>>         self.relu5 = nn.ReLU()
      2323  +  >>>         self.transpose = P.Transpose()
      2324  +  >>>         self.matmul1 = P.MatMul()
      2325  +  >>>         self.matmul2 = P.MatMul()
      2326  +  >>>
      2327  +  >>>     def construct(self, x):
      2328  +  >>>         q = self.fc1(x)
      2329  +  >>>         k = self.fc2(x)
      2330  +  >>>         v = self.fc3(x)
      2331  +  >>>         k = self.transpose(k, (1, 0))
      2332  +  >>>         c = self.relu4(self.matmul1(q, k))
      2333  +  >>>         s = self.relu5(self.matmul2(c, v))
      2334  +  >>>         s = self.fc4(s)
      2335  +  >>>         return s
2193  2336     >>>
2194  2337     >>> ms.set_context(mode=ms.GRAPH_MODE)
2195  2338     >>> init()
2196        -  >>>
2197        -  >>>
2198        -  >>>
2199        -  >>>
      2339  +  >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
      2340  +  >>> net = Net()
      2341  +  >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
      2342  +  >>> model = Model(parallel_net)
      2343  +  >>> predict_map = model.infer_predict_layout(inputs)
2200  2344     """
      2345  +  _init_auto_parallel_context(self._network)
2201  2346     if context.get_context("mode") != context.GRAPH_MODE:
2202  2347         raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
2203  2348                            "only supports GRAPH MODE and Ascend target currently.")
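
As a hedged follow-up to the example above, the layout dict returned by infer_predict_layout is commonly passed to mindspore.load_distributed_checkpoint when sharded weights are restored for distributed inference (the checkpoint file names are placeholders):

    import mindspore as ms

    layout = model.infer_predict_layout(inputs)
    ckpt_files = ["rank_0/checkpoint_0.ckpt", "rank_1/checkpoint_1.ckpt"]   # placeholder shard list
    ms.load_distributed_checkpoint(net, ckpt_files, layout)                 # load sharded parameters onto net
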
@@ -2217,6 +2362,7 @@ class Model:
2217  2362         predict_net.phase = origin_phase
2218  2363     else:
2219  2364         predict_net.compile(*predict_data)
      2365  +  _clear_auto_parallel_context(self._network)
2220  2366     return predict_net.parameter_layout_dict
2221  2367
2222  2368     def _flush_from_cache(self, cb_params):

@@ -2256,16 +2402,5 @@ class Model:
2256  2402     """
2257  2403     return self._eval_network
2258  2404
2259        -  def _prepare_obf_dataset(self, dataset):
2260        -      if not hasattr(self._network, 'obf_ratios'):
2261        -          return dataset
2262        -      data_size = dataset.get_dataset_size()
2263        -      obf_ratio_dataset = []
2264        -      for _ in range(data_size):
2265        -          obf_ratio_dataset.append(self._network.obf_ratios)
2266        -      obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
2267        -      dataset = ds.zip((dataset, obf_ratio_dataset))
2268        -      return dataset
2269        -
2270  2405
2271  2406  __all__ = ["Model"]