mindspore-2.5.0-cp39-cp39-win_amd64.whl → mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/model.py
CHANGED
@@ -27,7 +27,6 @@ import time
 import numpy as np
 
 import mindspore
-import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
@@ -36,7 +35,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
 from mindspore._checkparam import check_input_data, check_output_data
 from mindspore import _checkparam as Validator
 from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
-
+    TrainFaultTolerance
 from mindspore.train.callback import __all__ as internal_cb_names
 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
@@ -57,7 +56,9 @@ from mindspore.dataset.core.config import get_debug_mode
 from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
 from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
+from mindspore._c_expression import _get_optimzer_timestamps
 
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 
 def _transfer_tensor_to_tuple(inputs):
     """
@@ -91,6 +92,7 @@ def _save_final_ckpt(func):
     """
     Decorator function, which saves the current checkpoint when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
@@ -107,7 +109,7 @@ def _save_final_ckpt(func):
                 # pylint: disable=W0212
                 prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
                 cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
-
+                    + str(self._current_step_num) + "_breakpoint.ckpt"
                 cur_file = os.path.join(obj._directory, cur_ckpoint_file)
                 if "epoch_num" in obj._append_dict:
                     obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
@@ -118,56 +120,82 @@ def _save_final_ckpt(func):
                 raise e
         else:
             func(self, *args, **kwargs)
+
     return wrapper
 
+
+def _handle_exception_info(obj, uce_env, tft, e):
+    """handle exception info"""
+    logger.info("uce wrapper caught RuntimeError")
+    if not uce_env:
+        logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
+                     exc_info=True)
+        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+    e_str = str(e)
+    logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+    if "UCEError" in e_str:
+        logger.info("uce wrapper report UCEError")
+        obj.is_uce_rank = True
+        # if error is HBM_MULTI_BIT_ECC_ERROR
+        if "error_code=507054" in e_str:
+            hbm_error_time, optimize_start, optimizer_end = _get_optimzer_timestamps()
+            can_repair = tft.tft_can_do_uce_repair(hbm_error_time, optimize_start, optimizer_end)
+            logger.info(f"UCEError of type HBM_MULTI_BIT_ECC_ERROR occurs, \
+                        hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, \
+                        optimizer_end={optimizer_end}, can_repair={can_repair}")
+            if not can_repair:
+                logger.error(f"Caught UCEError of type HBM_MULTI_BIT_ECC_ERROR but can not repair, "
+                             f"hbm_error_time={hbm_error_time}, optimize_start={optimize_start}, "
+                             f"optimizer_end={optimizer_end}", exc_info=True)
+                tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                raise e
+        tft.tft_report_error(tft.ReportState.RS_UCE.value)
+    elif "ForceStopError" in e_str:
+        logger.warning("uce wrapper caught RuntimeError ForceStopError")
+        force_stop_err = tft.ReportState.RS_NORMAL.value
+        tft.tft_report_error(force_stop_err)
+    elif "ARF FINISH" in e_str:
+        logger.warning(f"ARF FINISH")
+        _set_recovery_context(is_arf=True)
+        tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
+    else:
+        logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
+        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+        raise e
+
+
 def _handle_tft(func):
     """
     Decorator function, which starts uce handle process when an exception occurs during training.
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         obj = None
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'),
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
            obj = kwargs.get('callbacks')
         if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
             for item in kwargs.get('callbacks'):
-                if isinstance(item,
+                if isinstance(item, TrainFaultTolerance):
                     obj = item
         if obj:
             tft = obj.tft
             tft_env = os.getenv("MS_ENABLE_TFT", "")
-            uce_env = "UCE:1" in tft_env
+            uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
             while True:
                 try:
                     return func(self, *args, **kwargs)
                 except RuntimeError as e:
-
-                    if not uce_env:
-                        logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
-                                     exc_info=True)
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
-                    e_str = str(e)
-                    logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
-                    if "UCEError" in e_str:
-                        logger.info("uce wrapper report UCEError")
-                        obj.is_uce_rank = True
-                        tft.tft_report_error(tft.ReportState.RS_UCE.value)
-                    elif "ForceStopError" in e_str:
-                        logger.info("uce wrapper caught RuntimeError ForceStopError")
-                        force_stop_err = tft.ReportState.RS_NORMAL.value
-                        tft.tft_report_error(force_stop_err)
-                    else:
-                        logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
-                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
-                        raise e
+                    _handle_exception_info(obj, uce_env, tft, e)
                     ret = tft.tft_wait_next_action()
                     if ret == tft.Action.EXIT.value:
                         raise e
                     repair_step = tft.tft_get_repair_step()
-                    logger.
-                    {}".format(repair_step,
-
+                    logger.warning(
+                        "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
                                                                                                self.batch_num))
+                    initial_epoch = int(repair_step / self.batch_num)
                     initial_step = repair_step % self.batch_num
                     kwargs["initial_epoch"] = initial_epoch
 
@@ -190,9 +218,9 @@ def _handle_tft(func):
                         kwargs["initial_step"] = cb_initial_step
                     # reset all accu grads to zero
                     obj._reset_acc_grads()
-
-
-
+                    logger.warning(
+                        "uce wrapper repair complete initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch,
                                                                                                      cb_initial_step))
                     continue
             except BaseException as e:
                 logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
@@ -200,6 +228,7 @@ initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
                 raise e
         else:
             return func(self, *args, **kwargs)
+
     return wrapper
 
 
@@ -216,7 +245,7 @@ def _check_tft():
     if ms_mode != mindspore.GRAPH_MODE:
         raise ValueError("TFT is only supported in GRAPH_MODE")
     jit_level = context.get_context("jit_level")
-    if jit_level == "O2" and "UCE:1" in tft_env:
+    if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
         raise ValueError("TFT is not supported when using jit_level == O2")
 
 
@@ -406,12 +435,13 @@ class Model:
             the accuracy is reduced by less than 3%.
 
             If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
-            In order for this function to work, you need to set the optimizer
-            at the
+            In order for this function to work, you need to set the parameter `optimizer`, along with
+            at least one of the parameter `eval_network` or performance `metrics`.
 
             Notice: The current optimization enabled by default only applies to some networks, and not all networks
             can obtain the same benefits. It is recommended to enable this function on
-            the Graph mode + Ascend platform, and for better acceleration,
+            the Graph mode + Ascend platform, and for better acceleration,
+            refer to :class:`mindspore.boost.AutoBoost` to configure
             boost_config_dict.
 
     Examples:
@@ -436,6 +466,7 @@ class Model:
     def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
                  amp_level="O0", boost_level="O0", **kwargs):
         self._network = network
+        _init_auto_parallel_context(self._network)
         self._loss_fn = loss_fn
         self._optimizer = optimizer
         self._loss_scale_manager = None
@@ -470,6 +501,7 @@ class Model:
         self._lite_infer = True  # if backend lite infer fails, set False
         self._mindspore_lite_model_group_id = id(self) & 0xFFFF
         self.batch_num = -1
+        _clear_auto_parallel_context(self._network)
 
     def _check_for_graph_cell(self, kwargs):
         """Check for graph cell"""
@@ -765,7 +797,7 @@ class Model:
                 break
             logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
 
-    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
         """
         Initialize compute graphs and data graphs with the sink mode.
 
@@ -794,7 +826,6 @@ class Model:
             if not isinstance(train_dataset, mindspore.dataset.Dataset):
                 raise TypeError("The type of 'train_dataset' must be `Dataset`, "
                                 "but got {}.".format(type(train_dataset)))
-
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                        "Begin to check parameter broadcast in model.build().")
             logger.info("Begin to check parameter broadcast in model.build() procedure.")
@@ -807,23 +838,24 @@ class Model:
             train_dataset.__no_send__ = True
             train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                         dataset=train_dataset,
-                                                                        dataset_sink_mode=
+                                                                        dataset_sink_mode=sink_mode,
                                                                         sink_size=sink_size)
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
-
-
+            if sink_mode:
+                logger.info("Begin to warmup dataset in model.build() procedure.")
+                self._warmup_dataset(epoch, train_dataset, sink_size)
 
-
-
+                # Since dataset pipeline has been triggered, delete flag
+                delattr(train_dataset, "__no_send__")
 
-
-
-
-
-
-
-
-
+                # Waiting for the dataset warmup ready
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "Begin waiting for dataset warmup in model.build().")
+                logger.info("Begin waiting for dataset warmup in model.build() procedure.")
+                self._waiting_for_dataset_warmup_ready(train_dataset)
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "The dataset warmup was successful in model.build().")
+                logger.info("The dataset warmup was successful in model.build() procedure.")
 
             if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
                 train_network.add_flags_recursive(is_first_iteration=True)
@@ -833,6 +865,7 @@ class Model:
                 logger.info("Begin to compile train network in model.build() procedure.")
                 train_network.compile(*inputs)
                 self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
+                train_dataset.reset()
                 break
 
         if valid_dataset:
@@ -846,7 +879,7 @@ class Model:
             valid_dataset.__no_send__ = True
             valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
                                                                        dataset=valid_dataset,
-                                                                       dataset_sink_mode=
+                                                                       dataset_sink_mode=sink_mode)
             if context.get_auto_parallel_context("pipeline_stages") > 1:
                 eval_network.add_flags_recursive(is_first_iteration=False)
             for inputs in valid_dataset_helper:
@@ -854,6 +887,7 @@ class Model:
                            "Begin to compile eval network in model.build().")
                 logger.info("Begin to compile eval network in model.build() procedure.")
                 eval_network.compile(*inputs)
+                valid_dataset.reset()
                 break
 
     @staticmethod
@@ -922,6 +956,7 @@ class Model:
         cb_params.last_save_ckpt_step = None
         cb_params.latest_ckpt_file = None
         cb_params.loss_scale_mananger = self._loss_scale_manager
+        cb_params.is_arf = _get_recovery_context("is_arf")
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
@@ -1026,6 +1061,9 @@ class Model:
                 need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
                 if need_exec_callback_step_end:
                     list_callback.on_train_step_end(run_context)
+                if cb_params.is_arf:
+                    cb_params.is_arf = False
+                    _set_recovery_context(is_arf=False)
 
                 # Embedding cache server only run one step.
                 if is_embedding_cache_server:
@@ -1056,7 +1094,7 @@ class Model:
             if should_stop:
                 break
 
-            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset")\
+            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
                                       and not _get_recovery_context("latest_ckpt_file")
             self.epoch_iter += 1
             if need_reset_to_beginning:
@@ -1100,7 +1138,7 @@ class Model:
         Check whether enable recovery and execution mode consistency.
         """
 
-        enable_recovery = _get_recovery_context("enable_recovery")
+        enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
         if not enable_recovery:
             self.enable_recovery = False
         else:
@@ -1117,6 +1155,8 @@ class Model:
             dataset_size (int): The number of batches in a dataset.
             sink_size (int): Control the amount of data in each sink. Default: -1.
         """
+        if context.get_context("device_target") != "GPU":
+            return
         if not self.enable_recovery:
             self.need_load_ckpt = False
 
@@ -1145,7 +1185,7 @@ class Model:
                 load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
             except BaseException as e:
                 os.remove(cb_params.latest_ckpt_file)
-                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
+                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
                                    + cb_params.latest_ckpt_file) from e
             _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
             self.need_load_ckpt = False
@@ -1235,6 +1275,9 @@ class Model:
                         self._loss_scale_manager.update_loss_scale(overflow)
 
                 list_callback.on_train_step_end(run_context)
+                if cb_params.is_arf:
+                    cb_params.is_arf = False
+                    _set_recovery_context(is_arf=False)
                 # Embedding cache server only run one step.
                 if is_embedding_cache_server:
                     break
@@ -1332,10 +1375,9 @@ class Model:
            ...                   loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
+        _init_auto_parallel_context(self._network)
        _check_tft()
        device_target = context.get_context("device_target")
-        # prepare dataset for obfuscated model
-        train_dataset = self._prepare_obf_dataset(train_dataset)
        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
            dataset_sink_mode = False
@@ -1391,6 +1433,8 @@ class Model:
        if _enable_distributed_mindrt():
            _reset_op_id_with_offset()
 
+        _clear_auto_parallel_context(self._network)
+
    @staticmethod
    def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
        if get_debug_mode() and dataset_sink_mode:
@@ -1484,11 +1528,8 @@ class Model:
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
            >>> model.fit(2, train_dataset, valid_dataset)
-
-        Tutorial Examples:
-            - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
        """
+        _init_auto_parallel_context(self._network)
        device_target = context.get_context("device_target")
        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
@@ -1540,8 +1581,9 @@ class Model:
                  valid_dataset=valid_dataset,
                  valid_frequency=valid_frequency,
                  valid_dataset_sink_mode=valid_dataset_sink_mode)
+        _clear_auto_parallel_context(self._network)
 
-    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, sink_mode=True):
        """
        Build computational graphs and data graphs with the sink mode.
 
@@ -1560,6 +1602,7 @@ class Model:
                will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
            sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
            epoch (int): Control the training epochs. Default: ``1`` .
+            sink_mode (bool): Determines whether to pass the data through dataset channel. Default: ``True`` .
 
        Examples:
            >>> from mindspore import nn
@@ -1580,16 +1623,18 @@ class Model:
            >>> model.build(dataset, epoch=2)
            >>> model.train(2, dataset)
        """
+        _init_auto_parallel_context(self._network)
        epoch = Validator.check_positive_int(epoch)
        if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
            self._train_network.check_names_and_refresh_name()
            self._train_network._is_check_and_refresh = True
        vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
        logger.info("Begin to init dataset in model.build() procedure.")
-        self._init(train_dataset, valid_dataset, sink_size, epoch)
+        self._init(train_dataset, valid_dataset, sink_size, epoch, sink_mode)
        vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                   "The model.build() which contains dataset warmup and network compile is success.")
        logger.info("The model.build() which contains dataset warmup and network compile is success.")
+        _clear_auto_parallel_context(self._network)
 
    def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
        """
@@ -1759,12 +1804,8 @@ class Model:
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> acc = model.eval(dataset, dataset_sink_mode=False)
-
-        Tutorial Examples:
-            - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
        """
-
+        _init_auto_parallel_context(self._network)
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
 
        _device_number_check(self._parallel_mode, self._device_number)
@@ -1809,6 +1850,7 @@ class Model:
        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
        if _enable_distributed_mindrt():
            _reset_op_id_with_offset()
+        _clear_auto_parallel_context(self._network)
 
        return eval_result
 
@@ -1821,7 +1863,8 @@ class Model:
                The predict data, can be a single tensor,
                a list of tensor, or a tuple of tensor.
 
-            config (dict, optional)
+            config (dict, optional): The config parameter is enabled when the backend is ‘lite’.
+
                The config includes two parts: config_path (configPath, str) and config_item (str, dict).
                When the config_item is set, its priority is higher than the config_path. Set the ranking
                table file for inference. The content of the configuration file is as follows:
@@ -1831,6 +1874,16 @@ class Model:
                For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
                config.ini file:
 
+                The config has 3 forms:
+                1. configPath defines the path of the configuration file, which is used to pass user-defined
+                options during model building. Default value: ``"" ``.
+
+                .. code-block::
+
+                    config = {"configPath" : "/home/user/config.ini"}
+
+                Here is the content of the config.ini file:
+
                .. code-block::
 
                    [ascend_context]
@@ -1839,20 +1892,15 @@ class Model:
                    [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
                    [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)
 
-
-
-                .. code-block::
-
-                    config = {"configPath" : "/home/user/config.ini"}
-
-                When only the config_dict is configured, it is done as follows:
+                2. Set the user-defined options in parameter dictionary, it is done as follows:
 
                .. code-block::
 
                    config = {"ascend_context" : {"rank_table_file" : "path_b"},
                              "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
 
-
+                3. Both the `configPath` and the `parameter dictionary` are configured, The priority of the parameter
+                dictionary is higher than that of the content in the configuration file. It is done as follows:
 
                .. code-block::
 
@@ -1860,12 +1908,13 @@ class Model:
                              "ascend_context" : {"rank_table_file" : "path_b"},
                              "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
 
-                Note that
-
+                Note that in the "configPath" the parameter is set as "rank_table_file = [path_a]", but in dict is set
+                as "ascend_context" : {"rank_table_file" : "path_b"}, in this case, the path_b takes precedence.
 
        Returns:
            Tensor, array(s) of predictions.
        """
+
        def _get_lite_context(lite_context_input):
            # use default lite context parameters for now
            device_target = context.get_context("device_target").lower()
@@ -1899,7 +1948,7 @@ class Model:
        if not self._mindspore_lite:
            self._mindspore_lite = importlib.import_module('mindspore_lite')
 
-        use_past = False
+        use_past = False  # default execute full model inference
        model_group_id = None
        if self._predict_network.get_flags().__contains__("is_first_iteration"):
            is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
@@ -2012,6 +2061,7 @@ class Model:
            >>> model = Model(LeNet5())
            >>> result = model.predict(input_data)
        """
+        _init_auto_parallel_context(self._network)
        if backend not in ['lite', None]:
            raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
        if backend == "lite" and self._lite_infer:
@@ -2027,6 +2077,7 @@ class Model:
            except BaseException as e:
                self._lite_infer = False
                logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
+        _clear_auto_parallel_context(self._network)
 
        def _check_input_data():
            """Input data check."""
@@ -2092,7 +2143,9 @@ class Model:
 
    def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
        """
-        Generate parameter layout for the train network
+        Generate parameter layout for the train network when using `AutoParallel(cell)`
+        to enable parallel mode.
+
        Only dataset sink mode is supported for now.
 
        .. warning::
@@ -2111,9 +2164,9 @@ class Model:
                Configure pynative mode or CPU, the training process will be performed with
                dataset not sink. Default: ``True`` .
            sink_size (int): Control the number of steps for each sinking.
+                If dataset_sink_mode is False, set sink_size as invalid.
                If sink_size = -1, sink the complete dataset for each epoch.
                If sink_size > 0, sink sink_size data for each epoch.
-                If dataset_sink_mode is False, set sink_size as invalid.
                Default: ``-1`` .
 
        Returns:
@@ -2127,10 +2180,10 @@ class Model:
            >>> from mindspore import Tensor, nn
            >>> from mindspore.train import Model
            >>> from mindspore.communication import init
+            >>> from mindspore.parallel.auto_parallel import AutoParallel
            >>>
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> init()
-            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
            >>>
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
@@ -2138,13 +2191,15 @@ class Model:
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
+            >>> parallel_net = AutoParallel(net)
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = ms.FixedLossScaleManager()
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
-            >>> model = Model(
+            >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim, metrics=None,
            ...                 loss_scale_manager=loss_scale_manager)
            >>> layout_dict = model.infer_train_layout(dataset)
        """
+        _init_auto_parallel_context(self._network)
        self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
 
        train_dataset.__no_send__ = True
@@ -2156,11 +2211,13 @@ class Model:
            train_network.compile(*inputs)
            break
        train_dataset.__model_hash__ = hash(self)
+        _clear_auto_parallel_context(self._network)
        return train_network.parameter_layout_dict
 
    def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
        """
-        Generate parameter layout for the predict network
+        Generate parameter layout for the predict network when using `AutoParallel(cell)`
+        to enable parallel mode.
 
        Data could be a single tensor or multiple tensors.
 
@@ -2183,21 +2240,47 @@ class Model:
            RuntimeError: If not in GRAPH_MODE.
 
        Examples:
-            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
-            >>> # mindspore.cn.
            >>> import numpy as np
-            >>> import mindspore as
+            >>> import mindspore.nn as nn
            >>> from mindspore import Tensor
            >>> from mindspore.train import Model
+            >>> from mindspore.ops import operations as P
+            >>> from mindspore import context
            >>> from mindspore.communication import init
+            >>> from mindspore.parallel.auto_parallel import AutoParallel
+            >>>
+            >>> class Net(nn.Cell):
+            >>>     def __init__(self):
+            >>>         super(Net, self).__init__()
+            >>>         self.fc1 = nn.Dense(128, 768, activation='relu')
+            >>>         self.fc2 = nn.Dense(128, 768, activation='relu')
+            >>>         self.fc3 = nn.Dense(128, 768, activation='relu')
+            >>>         self.fc4 = nn.Dense(768, 768, activation='relu')
+            >>>         self.relu4 = nn.ReLU()
+            >>>         self.relu5 = nn.ReLU()
+            >>>         self.transpose = P.Transpose()
+            >>>         self.matmul1 = P.MatMul()
+            >>>         self.matmul2 = P.MatMul()
+            >>>
+            >>>     def construct(self, x):
+            >>>         q = self.fc1(x)
+            >>>         k = self.fc2(x)
+            >>>         v = self.fc3(x)
+            >>>         k = self.transpose(k, (1, 0))
+            >>>         c = self.relu4(self.matmul1(q, k))
+            >>>         s = self.relu5(self.matmul2(c, v))
+            >>>         s = self.fc4(s)
+            >>>         return s
            >>>
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> init()
-            >>>
-            >>>
-            >>>
-            >>>
+            >>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+            >>> net = Net()
+            >>> parallel_net = AutoParallel(net, parallel_mode='semi_auto')
+            >>> model = Model(parallel_net)
+            >>> predict_map = model.infer_predict_layout(inputs)
        """
+        _init_auto_parallel_context(self._network)
        if context.get_context("mode") != context.GRAPH_MODE:
            raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
                               "only supports GRAPH MODE and Ascend target currently.")
@@ -2217,6 +2300,7 @@ class Model:
            predict_net.phase = origin_phase
        else:
            predict_net.compile(*predict_data)
+        _clear_auto_parallel_context(self._network)
        return predict_net.parameter_layout_dict
 
    def _flush_from_cache(self, cb_params):
@@ -2256,16 +2340,5 @@ class Model:
        """
        return self._eval_network
 
-    def _prepare_obf_dataset(self, dataset):
-        if not hasattr(self._network, 'obf_ratios'):
-            return dataset
-        data_size = dataset.get_dataset_size()
-        obf_ratio_dataset = []
-        for _ in range(data_size):
-            obf_ratio_dataset.append(self._network.obf_ratios)
-        obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
-        dataset = ds.zip((dataset, obf_ratio_dataset))
-        return dataset
-
 
 __all__ = ["Model"]
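The model.py changes above add a `sink_mode` argument to `Model.build` (forwarded to `_init`), rename the fault-tolerance callback to `TrainFaultTolerance`, and extend the `MS_ENABLE_TFT` handling to accept `ARF:1` alongside `UCE:1`. Below is a minimal sketch of how that surface might be exercised with 2.6.0rc1, assuming only the names visible in this diff; the toy data, the CPU/graph-mode run, and the argument-free `TrainFaultTolerance()` construction are illustrative assumptions, not taken from this diff.

    # Hedged sketch based on the 2.6.0rc1 diff above; not an official example.
    import numpy as np
    import mindspore as ms
    import mindspore.dataset as ds
    from mindspore import nn
    from mindspore.train import Model

    ms.set_context(mode=ms.GRAPH_MODE)

    # Toy regression data and a one-layer network.
    x = np.random.randn(64, 8).astype(np.float32)
    y = np.random.randn(64, 1).astype(np.float32)
    train_set = ds.NumpySlicesDataset({"data": x, "label": y}, shuffle=False).batch(8)

    net = nn.Dense(8, 1)
    model = Model(net, loss_fn=nn.MAELoss(),
                  optimizer=nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9))

    # New in 2.6.0rc1 per the diff: build() accepts sink_mode and forwards it to _init(),
    # so graph compilation can be requested without the dataset-sink warmup path.
    model.build(train_dataset=train_set, epoch=1, sink_mode=False)
    model.train(1, train_set, dataset_sink_mode=False)

    # Fault-tolerance path sketched from the same diff; left inert because it presupposes
    # an Ascend device with the MindIO TTP component, which this diff does not confirm.
    # import os
    # os.environ["MS_ENABLE_TFT"] = "UCE:1"   # "UCE:1" or "ARF:1" enable the uce_env branch in _handle_tft
    # from mindspore.train.callback import TrainFaultTolerance   # renamed from the old _tft_register callback
    # model.train(1, train_set, callbacks=[TrainFaultTolerance()], dataset_sink_mode=False)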