mindspore 2.5.0__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -17,6 +17,7 @@
 """Providing interface methods."""
 from __future__ import absolute_import
 
+import gc
 import types
 import sys
 import os
@@ -24,11 +25,11 @@ import time
 import ast
 import inspect
 import importlib
-import hashlib
 import contextlib
+import json
 from collections import OrderedDict, namedtuple
 from functools import wraps
-import
+from typing import Optional, Callable
 import mindspore as ms
 from mindspore import context
 from mindspore import log as logger
@@ -39,21 +40,23 @@ from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
-from mindspore._c_expression import GraphExecutor_,
+from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle, _bind_device_ctx, StubNode
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
 from mindspore.parallel._ps_context import _is_role_sched
-from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast,
-
+from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
+    _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
 from mindspore.common.mutable import mutable, _check_element_type
-from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
 from mindspore.common._pijit_context import PIJitCaptureContext
 from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
+from mindspore.common.jit_context import jit_context
+from mindspore.common.jit_trace import _jit_trace
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 
 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
@@ -107,8 +110,7 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
             logger.info(f"The {echo_function_name} has been compiled again. "
                         f"{tips} ")
         else:
-            tips = "Try to
-                   "or @jit(compile_once=True) to reduce the compile time. " \
+            tips = "Try to reuse the function object decorated by @jit to reduce the compile time. " \
                    "For more details, get instructions about `jit` at " \
                    "https://www.mindspore.cn/search?inputValue=jit."
             logger.warning(f"The {echo_function_name} has been compiled again. "
@@ -120,14 +122,6 @@ def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time,
     function_phases[full_function_name].add(create_time)
 
 
-def _ms_adapter_tensor_as_parameter_output(data):
-    """Check whether the data is an output from a parameter which is a ms_adapter tensor.
-    Pylint: disable=unidiomatic-typecheck.
-    """
-    return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
-        and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
-
-
 def _convert_python_data(data):
     """
     Convert C++ data to python.
@@ -138,18 +132,8 @@ def _convert_python_data(data):
     Returns:
         data, a data convert C++ to python
     """
-    if isinstance(data,
-        return
-    if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
-        return data.tensor
-    if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
-        return PythonTensor(data, internal=True)
-    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
-        return PythonCSRTensor(csr_tensor=data)
-    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
-        return PythonCOOTensor(coo_tensor=data)
-    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
-        return PythonRowTensor(row_tensor=data)
+    if isinstance(data, PythonTensor):
+        return data
     if isinstance(data, StubNode):
         return ms.common._stub_tensor._convert_stub(data)
     if data.__class__ is tuple:
@@ -160,6 +144,12 @@ def _convert_python_data(data):
             fields = data_dict.keys()
             return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
         return tuple(_convert_python_data(x) for x in data)
+    if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
+        return PythonCSRTensor(csr_tensor=data)
+    if isinstance(data, COOTensor) and not isinstance(data, PythonCOOTensor):
+        return PythonCOOTensor(coo_tensor=data)
+    if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
+        return PythonRowTensor(row_tensor=data)
     if data.__class__ is list:
         # Keep list object not change for inplace operation.
         for i in range(len(data)):
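The _check_recompile hunk above rewrites its warning tip now that compile_once is gone. A minimal sketch of the pattern the new tip asks for, assuming a standard MindSpore install (the function body, shapes, and loop counts here are illustrative only):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def body(x):
    return ops.exp(x)

x = Tensor(np.ones([4]).astype(np.float32))

# Anti-pattern: wrapping `body` anew on every call creates a fresh jit object
# each time, so the compile cache is missed and _check_recompile warns about
# repeated compilation.
for _ in range(3):
    ms.jit(body)(x)

# What the new tip suggests instead: decorate once and reuse the object.
compiled_body = ms.jit(body)
for _ in range(3):
    compiled_body(x)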
@@ -578,11 +568,11 @@ def _get_hook_key(*args, **kwargs):
|
|
|
578
568
|
return hook_key
|
|
579
569
|
|
|
580
570
|
|
|
581
|
-
class
|
|
571
|
+
class _JitExecutor:
|
|
582
572
|
"""
|
|
583
573
|
Represents a function compiled by graph compiler.
|
|
584
574
|
|
|
585
|
-
|
|
575
|
+
_JitExecutor will compile the original function for every combination
|
|
586
576
|
of argument types and shapes it is given (as well as their values, optionally).
|
|
587
577
|
|
|
588
578
|
Args:
|
|
@@ -596,7 +586,7 @@ class _MindsporeFunctionExecutor:
|
|
|
596
586
|
The result of pipeline running in graph mode.
|
|
597
587
|
"""
|
|
598
588
|
|
|
599
|
-
def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
|
|
589
|
+
def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
|
|
600
590
|
init_pipeline()
|
|
601
591
|
if not isinstance(fn, (types.FunctionType, types.MethodType)):
|
|
602
592
|
raise RuntimeError('fn {} is not function or method'.format(fn))
|
|
@@ -608,13 +598,19 @@ class _MindsporeFunctionExecutor:
|
|
|
608
598
|
self.obj = obj
|
|
609
599
|
self.shard_parent_obj = obj
|
|
610
600
|
self.enable_tuple_broaden = False
|
|
611
|
-
|
|
601
|
+
if _run_jit_pipeline():
|
|
602
|
+
self._graph_executor = JitExecutor_.get_instance()
|
|
603
|
+
else:
|
|
604
|
+
self._graph_executor = GraphExecutor_.get_instance()
|
|
612
605
|
self._create_time = ms_create_time
|
|
613
606
|
self._compile_args = None
|
|
607
|
+
self._enable_auto_dynamic = dynamic == 1
|
|
614
608
|
self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
|
|
615
609
|
|
|
616
610
|
@_wrap_func
|
|
617
611
|
def __call__(self, *args, **kwargs):
|
|
612
|
+
if jit_context() and jit_context().is_nested():
|
|
613
|
+
return jit_context().run_graph("", None, *())
|
|
618
614
|
args_list = args
|
|
619
615
|
if self.obj is not None:
|
|
620
616
|
args_list = args_list[1:]
|
|
@@ -634,10 +630,14 @@ class _MindsporeFunctionExecutor:
|
|
|
634
630
|
return None
|
|
635
631
|
|
|
636
632
|
new_inputs = self._generate_run_args(args_list, kwargs)
|
|
637
|
-
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
633
|
+
if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
|
|
638
634
|
output = _pynative_executor.grad_jit(*new_inputs)
|
|
639
635
|
else:
|
|
640
636
|
output = self._graph_executor(tuple(new_inputs), phase)
|
|
637
|
+
if jit_context():
|
|
638
|
+
if is_stub_tensor(output):
|
|
639
|
+
output = output.stub_sync()
|
|
640
|
+
return jit_context().run_graph(phase, output, *tuple(new_inputs))
|
|
641
641
|
|
|
642
642
|
return output
|
|
643
643
|
|
|
@@ -653,7 +653,8 @@ class _MindsporeFunctionExecutor:
|
|
|
653
653
|
compile_args = self._generate_compile_args(args)
|
|
654
654
|
key_id = self._get_key_id()
|
|
655
655
|
compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
|
|
656
|
-
self.input_signature
|
|
656
|
+
self.input_signature,
|
|
657
|
+
self._enable_auto_dynamic)
|
|
657
658
|
|
|
658
659
|
# Add mutable for compile_args for two scene:
|
|
659
660
|
# 1) Origin args is mutable.
|
|
@@ -704,7 +705,7 @@ class _MindsporeFunctionExecutor:
|
|
|
704
705
|
|
|
705
706
|
update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
|
|
706
707
|
|
|
707
|
-
if phase in ms_compile_cache and not parameter_hook_updated():
|
|
708
|
+
if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
|
|
708
709
|
# Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
|
|
709
710
|
# generated in generate_arguments_key.
|
|
710
711
|
self._graph_executor.clear_compile_arguments_resource()
|
|
@@ -726,7 +727,7 @@ class _MindsporeFunctionExecutor:
|
|
|
726
727
|
setattr(self.fn.__func__, "__jit_function__", True)
|
|
727
728
|
else:
|
|
728
729
|
setattr(self.fn, "__jit_function__", True)
|
|
729
|
-
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase
|
|
730
|
+
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
|
|
730
731
|
if isinstance(self.fn, types.MethodType):
|
|
731
732
|
delattr(self.fn.__func__, "__jit_function__")
|
|
732
733
|
else:
|
|
@@ -734,7 +735,7 @@ class _MindsporeFunctionExecutor:
|
|
|
734
735
|
else:
|
|
735
736
|
if isinstance(self.obj, ms.nn.Cell):
|
|
736
737
|
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
|
737
|
-
is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase
|
|
738
|
+
is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
|
|
738
739
|
|
|
739
740
|
if not is_compile:
|
|
740
741
|
raise RuntimeError("Executor compile failed.")
|
|
@@ -760,7 +761,7 @@ class _MindsporeFunctionExecutor:
|
|
|
760
761
|
else:
|
|
761
762
|
key_id = str(id(self.obj)) + str(self._create_time)
|
|
762
763
|
|
|
763
|
-
if _pynative_executor.
|
|
764
|
+
if _pynative_executor.requires_grad():
|
|
764
765
|
key_id = key_id + ".grad"
|
|
765
766
|
return key_id
|
|
766
767
|
|
|
@@ -770,9 +771,9 @@ class _MindsporeFunctionExecutor:
|
|
|
770
771
|
self.fn.__code__.co_firstlineno)
|
|
771
772
|
echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
|
|
772
773
|
+ "\", line " + str(self.fn.__code__.co_firstlineno)
|
|
773
|
-
if _pynative_executor.
|
|
774
|
+
if _pynative_executor.requires_grad():
|
|
774
775
|
generate_name = generate_name + ".grad"
|
|
775
|
-
if
|
|
776
|
+
if self.fn.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
|
|
776
777
|
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
|
|
777
778
|
return generate_name, echo_function_name
|
|
778
779
|
|
|
@@ -833,6 +834,14 @@ class _MindsporeFunctionExecutor:
|
|
|
833
834
|
"""
|
|
834
835
|
return _get_args_for_run(self, args_list, kwargs, self._compile_args)
|
|
835
836
|
|
|
837
|
+
def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
|
|
838
|
+
"""Get graph proto from pipeline."""
|
|
839
|
+
if use_prefix:
|
|
840
|
+
exec_id = exec_id + '.' + obj.arguments_key
|
|
841
|
+
if self._graph_executor.has_compiled(exec_id) is False:
|
|
842
|
+
return None
|
|
843
|
+
return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
|
|
844
|
+
|
|
836
845
|
|
|
837
846
|
# The attributes used to identify a given object.
|
|
838
847
|
attr_op = {"__str__": lambda x: x.__str__(),
|
|
@@ -845,6 +854,13 @@ attr_op = {"__str__": lambda x: x.__str__(),
|
|
|
845
854
|
}
|
|
846
855
|
|
|
847
856
|
|
|
857
|
+
def _is_inner_func(func):
|
|
858
|
+
"""Check whether the func is an inner func which needs hash_args parameter."""
|
|
859
|
+
# This is a workaround for inner api, should fix it later.
|
|
860
|
+
inner_func = ["after_shard", "_wrap_container"]
|
|
861
|
+
return func.__name__ in inner_func
|
|
862
|
+
|
|
863
|
+
|
|
848
864
|
def _get_obj_id(input_obj):
|
|
849
865
|
"""Get hash id of single object."""
|
|
850
866
|
obj_id = ".".join(
|
|
@@ -859,50 +875,227 @@ def _get_jit_hash(hash_input):
|
|
|
859
875
|
return _get_obj_id(hash_input)
|
|
860
876
|
|
|
861
877
|
|
|
862
|
-
def
|
|
878
|
+
def _get_hash_obj(options):
|
|
879
|
+
hash_obj = None
|
|
880
|
+
if "hash_args" in options:
|
|
881
|
+
hash_obj = _get_jit_hash(options["hash_args"])
|
|
882
|
+
del options["hash_args"]
|
|
883
|
+
return hash_obj
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def _check_option_device(option, device):
|
|
887
|
+
"""Check jit options wiwh device"""
|
|
888
|
+
option_device_cfgs = {
|
|
889
|
+
'disable_format_transform': ['GPU'],
|
|
890
|
+
'exec_order': ['Ascend'],
|
|
891
|
+
'ge_options': ['Ascend'],
|
|
892
|
+
'infer_boost': ['Ascend'],
|
|
893
|
+
}
|
|
894
|
+
if option in option_device_cfgs and device not in option_device_cfgs[option]:
|
|
895
|
+
logger.warning(f"For 'jit(options)', the option '{option}' is only support device in "
|
|
896
|
+
f"'{option_device_cfgs[option]}', but got '{device}', ignore it.")
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
def _check_option_backend(option, backend):
|
|
900
|
+
"""Check jit options wiwh backend"""
|
|
901
|
+
option_backend_cfgs = {
|
|
902
|
+
'disable_format_transform': ['ms_backend'],
|
|
903
|
+
'exec_order': ['ms_backend'],
|
|
904
|
+
'ge_options': ['GE'],
|
|
905
|
+
'infer_boost': ['ms_backend'],
|
|
906
|
+
}
|
|
907
|
+
if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
|
|
908
|
+
logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
|
|
909
|
+
f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
def _check_disable_format_transform_value(option, disable_format_transform):
|
|
913
|
+
"""check disable_format_transform option value"""
|
|
914
|
+
if not isinstance(disable_format_transform, bool):
|
|
915
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be bool, "
|
|
916
|
+
f"but got {type(disable_format_transform)}.")
|
|
917
|
+
|
|
918
|
+
|
|
919
|
+
def _check_exec_order_value(option, exec_order):
|
|
920
|
+
"""check exec_order option value"""
|
|
921
|
+
if not isinstance(exec_order, str):
|
|
922
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(exec_order)}.")
|
|
923
|
+
|
|
924
|
+
if exec_order not in ['bfs', 'dfs']:
|
|
925
|
+
raise ValueError(f"For '{option}', the value of '{option}' must be one of "
|
|
926
|
+
f"['bfs', 'dfs'], but got '{exec_order}'.")
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
def _check_ge_options_value(option, ge_options):
|
|
930
|
+
"""check ge_options option value"""
|
|
931
|
+
if not isinstance(ge_options, dict):
|
|
932
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be dict, but got {type(ge_options)}.")
|
|
933
|
+
|
|
934
|
+
for level, options in ge_options.items():
|
|
935
|
+
if level not in ['global', 'session']:
|
|
936
|
+
raise ValueError(f"For '{option}', the key of '{option}' must be one of "
|
|
937
|
+
f"['global', 'session'], but got '{level}'.")
|
|
938
|
+
|
|
939
|
+
if not isinstance(options, dict):
|
|
940
|
+
raise TypeError(f"For '{option}', the type of {level} options must be dict, "
|
|
941
|
+
f"but got {type(options)}. The error options: {options}.")
|
|
942
|
+
|
|
943
|
+
for key, value in options.items():
|
|
944
|
+
if not isinstance(key, str):
|
|
945
|
+
raise TypeError(f"For '{option}', the type of key and value must be str, "
|
|
946
|
+
f"but got {type(key)}. The error key is {key}.")
|
|
947
|
+
if not isinstance(value, str):
|
|
948
|
+
raise TypeError(f"For '{option}', the type of key and value must be str, "
|
|
949
|
+
f"but got {type(value)}. The error value is {value}")
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
def _check_infer_boost_value(option, value):
|
|
953
|
+
"""check infer_boost option value"""
|
|
954
|
+
if not isinstance(value, str):
|
|
955
|
+
raise TypeError(f"For 'jit(options)', the type of '{option}' must be str, but got {type(value)}.")
|
|
956
|
+
|
|
957
|
+
if value not in ['on', 'off']:
|
|
958
|
+
raise ValueError(f"For '{option}', the value of '{option}' must be one of ['on', 'off'], but got '{value}'.")
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
def _check_option_value(option, value):
|
|
962
|
+
"""check jit options wiwh value"""
|
|
963
|
+
option_valuecheck_funcs = {
|
|
964
|
+
'disable_format_transform': _check_disable_format_transform_value,
|
|
965
|
+
'exec_order': _check_exec_order_value,
|
|
966
|
+
'ge_options': _check_ge_options_value,
|
|
967
|
+
'infer_boost': _check_infer_boost_value,
|
|
968
|
+
}
|
|
969
|
+
if option in option_valuecheck_funcs:
|
|
970
|
+
option_valuecheck_funcs[option](option, value)
|
|
971
|
+
else:
|
|
972
|
+
logger.warning(f"For 'jit(options)', the option argument '{option}' is not recognized, please check!"
|
|
973
|
+
f"For detailed usage of 'jit(options)', please refer to the Mindspore official website.")
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
def _check_options(options, backend):
|
|
977
|
+
"""Check jit options"""
|
|
978
|
+
# check whether there are deprecated parameters in the dict `options`.
|
|
979
|
+
deprecated_args = {'mode': 'capture_mode', 'input_signature': 'dynamic', 'hash_args: ': '',
|
|
980
|
+
'jit_config': 'jit_level, fullgraph or options', 'compile_once': ''}
|
|
981
|
+
for key, value in deprecated_args.items():
|
|
982
|
+
if key in options:
|
|
983
|
+
log = f"For 'jit', the parameter '{key}' has been deprecated."
|
|
984
|
+
if value != '':
|
|
985
|
+
log += f" Please use the parameter '{value}' instead. For more details, please refer to " \
|
|
986
|
+
f"https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html."
|
|
987
|
+
logger.warning(log)
|
|
988
|
+
del options[key]
|
|
989
|
+
|
|
990
|
+
# check options' device, backend and value
|
|
991
|
+
for option, value in options.items():
|
|
992
|
+
_check_option_backend(option, backend)
|
|
993
|
+
_check_option_value(option, value)
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
def jit(
|
|
997
|
+
function: Optional[Callable] = None,
|
|
998
|
+
*,
|
|
999
|
+
capture_mode: str = "ast",
|
|
1000
|
+
jit_level: str = "O0",
|
|
1001
|
+
dynamic: int = 0,
|
|
1002
|
+
fullgraph: bool = False,
|
|
1003
|
+
backend: str = "",
|
|
1004
|
+
**options):
|
|
863
1005
|
"""
|
|
864
1006
|
Create a callable MindSpore graph from a Python function.
|
|
865
1007
|
|
|
866
1008
|
This allows the MindSpore runtime to apply optimizations based on graph.
|
|
867
1009
|
|
|
868
1010
|
Note:
|
|
869
|
-
-
|
|
870
|
-
|
|
871
|
-
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
decorated with @jit(mode=“PIJit”) are not supported,
|
|
875
|
-
and the decoration @jit(mode=“PIJit”) is considered invalid.
|
|
1011
|
+
- It is not supported to run a function with decoration @jit(capture_mode=“bytecode”)
|
|
1012
|
+
in static graph mode, in which case the decoration @jit(capture_mode=“bytecode”) is considered invalid.
|
|
1013
|
+
- Calls to functions with decorated @jit(capture_mode=“bytecode”) inside functions
|
|
1014
|
+
decorated with @jit(capture_mode=“ast”) are not supported,
|
|
1015
|
+
and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
|
|
876
1016
|
|
|
877
1017
|
Args:
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
1018
|
+
function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
|
|
1019
|
+
|
|
1020
|
+
Keyword Args:
|
|
1021
|
+
capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
|
|
1022
|
+
should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
|
|
1023
|
+
|
|
1024
|
+
- `ast <https://www.mindspore.cn/tutorials/en/master/compile/static_graph.html>`_ :
|
|
1025
|
+
Parse Python ast to build graph.
|
|
1026
|
+
- `bytecode` :
|
|
1027
|
+
Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
|
|
1028
|
+
change and/or deletion.
|
|
1029
|
+
- `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
|
|
1030
|
+
subject to change and/or deletion.
|
|
1031
|
+
|
|
1032
|
+
jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
|
|
1033
|
+
with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
|
|
1034
|
+
|
|
1035
|
+
- `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
|
|
1036
|
+
- `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
|
|
1037
|
+
level is experimental and is being improved.
|
|
1038
|
+
|
|
1039
|
+
dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
|
|
1040
|
+
is as follows:
|
|
1041
|
+
|
|
1042
|
+
- `0`: Do not perform dynamic shape compilation.
|
|
1043
|
+
- `1`: Enable dynamic shape compilation and automatically detect shape changes.
|
|
1044
|
+
|
|
1045
|
+
fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
|
|
1046
|
+
be compatible with all Python syntax in the function as much as possible. If True, we require that the
|
|
1047
|
+
entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
|
|
1048
|
+
not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
|
|
1049
|
+
Default: ``False``.
|
|
1050
|
+
backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
|
|
1051
|
+
use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
|
|
1052
|
+
A2 training series products by default.
|
|
1053
|
+
|
|
1054
|
+
- `ms_backend`: Adopt KernelByKernel execution mode.
|
|
1055
|
+
- `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
|
|
1056
|
+
the top cell of model. And only can be used in Ascend platform.
|
|
1057
|
+
|
|
1058
|
+
**options (dict): A dictionary of options to pass to the compilation backend.
|
|
1059
|
+
|
|
1060
|
+
Some options are device specific, see the below table for details:
|
|
1061
|
+
|
|
1062
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1063
|
+
| Option Parameters | Hardware Platform Support | Backend Support |
|
|
1064
|
+
+===========================+===========================+=========================+
|
|
1065
|
+
| disable_format_transform | GPU | ms_backend |
|
|
1066
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1067
|
+
| exec_order | Ascend | ms_backend |
|
|
1068
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1069
|
+
| ge_options | Ascend | GE |
|
|
1070
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1071
|
+
| infer_boost | Ascend | ms_backend |
|
|
1072
|
+
+---------------------------+---------------------------+-------------------------+
|
|
1073
|
+
|
|
1074
|
+
- disable_format_transform (bool, optional): Whether to disable the automatic format transform function
|
|
1075
|
+
from NCHW to NHWC. When the network training performance of fp16 is worse than fp32,
|
|
1076
|
+
`disable_format_transform` can be set to ``True`` to try to improve training performance.
|
|
1077
|
+
Default: ``False`` .
|
|
1078
|
+
- exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
|
|
1079
|
+
methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
|
|
1080
|
+
|
|
1081
|
+
- `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
|
|
1082
|
+
performance.
|
|
1083
|
+
- `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
|
|
1084
|
+
of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
|
|
1085
|
+
other execution orders run out of memory (OOM).
|
|
1086
|
+
|
|
1087
|
+
- ge_options (dict): Set options for ge backend. The options are divided into two categories: global,
|
|
1088
|
+
and session. This is an experimental prototype that is subject to change and/or deletion.
|
|
1089
|
+
For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/ascendgraphapi/atlasgeapi_07_0146.html>`_ .
|
|
1090
|
+
|
|
1091
|
+
- global (dict): Set global options.
|
|
1092
|
+
- session (dict): Set session options.
|
|
1093
|
+
|
|
1094
|
+
- infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
|
|
1095
|
+
the inference mode is disabled. The range is as follows:
|
|
1096
|
+
|
|
1097
|
+
- `on`: Enable inference mode, get better infer performance.
|
|
1098
|
+
- `off`: Disable inference mode, use forward for inference. The performance is poor.
|
|
906
1099
|
|
|
907
1100
|
Returns:
|
|
908
1101
|
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
|
@@ -921,12 +1114,12 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
|
|
|
921
1114
|
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
922
1115
|
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
923
1116
|
...
|
|
924
|
-
>>> # create a callable MindSpore graph by calling
|
|
1117
|
+
>>> # create a callable MindSpore graph by calling jit
|
|
925
1118
|
>>> def tensor_add(x, y):
|
|
926
1119
|
... z = x + y
|
|
927
1120
|
... return z
|
|
928
1121
|
...
|
|
929
|
-
>>> tensor_add_graph = jit(
|
|
1122
|
+
>>> tensor_add_graph = jit(function=tensor_add)
|
|
930
1123
|
>>> out = tensor_add_graph(x, y)
|
|
931
1124
|
...
|
|
932
1125
|
>>> # create a callable MindSpore graph through decorator @jit
|
|
@@ -937,180 +1130,70 @@ def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=
         ...
         >>> out = tensor_add_with_dec(x, y)
         ...
-        >>> # create a callable MindSpore graph through decorator @jit with input_signature parameter
-        >>> @jit(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
-        ...                       Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
-        ... def tensor_add_with_sig(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> out = tensor_add_with_sig(x, y)
-        ...
-        >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
-        ... def tensor_add_with_sig_1(x, y):
+        >>> # create a callable MindSpore graph and capture the entire function into the graph
+        >>> @jit(fullgraph=True)
+        ... def tensor_add_fullgraph(x, y):
         ...     z = x + y
         ...     return z
         ...
-        >>>
-        ...
-        ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
-        ... # While fn differs during calling again, recompilation will be triggered.
-        >>> def func(x):
-        ...     return ops.exp(x)
-        ...
-        >>> def closure_fn(x, fn):
-        ...     @jit(hash_args=fn)
-        ...     def inner_fn(a):
-        ...         return fn(a)
-        ...     return inner_fn(x)
-        ...
-        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
-        >>> for i in range(10):
-        ...     closure_fn(inputs, func)
-        ...
-        ... # Set compile_once = True, otherwise the train_step will be compiled again.
-        >>> def train(x):
-        ...     @jit(compile_once = True)
-        ...     def train_step(x):
-        ...         return ops.exp(x)
-        ...     for i in range(10):
-        ...         train_step(x)
-        ...
-        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
-        >>> for i in range(10):
-        ...     train(inputs)
+        >>> out = tensor_add_fullgraph(x, y)
     """
 
-
-
-
-
-
-
-
-
-
+    capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
+    jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
+    dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
+    fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
+    if backend == "":
+        backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
+    backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
+    jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
+    hash_obj = _get_hash_obj(options)
+    _check_options(options, backend)
+    options_str = json.dumps(options)
+    infer_boost = options['infer_boost'] if 'infer_boost' in options else "off"
+    exc_mode = options['exc_mode'] if 'exc_mode' in options else "auto"
+    jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
+                           infer_boost=infer_boost, backend=backend, options=options_str)
+
+    def wrap_func(func):
+        nonlocal hash_obj
+        if hash_obj is None or not _is_inner_func(func):
             hash_obj = int(time.time() * 1e9)
 
-        dyn_args = _process_dyn_args(func, input_signature)
-
         @wraps(func)
         def staging_specialize(*args, **kwargs):
             if os.getenv("MS_JIT") == '0':
                 return func(*args, **kwargs)
 
             args, kwargs = _handle_func_args(func, *args, **kwargs)
-
             process_obj = None
             if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
                 process_obj = args[0]
-            # only the function or cell instance wrapped by shard will fall into this branch
-            if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
-                process_obj = hash_args
             # Handle auto mixed precision strategy.
             if not hasattr(func, "amp_strategy"):
                 if isinstance(func, types.MethodType):
                     setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
                 else:
                     setattr(func, "amp_strategy", get_curr_amp_strategy())
-
+
+            ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
+            out = ms_function_executor(*args, **kwargs)
             return out
 
         return staging_specialize
 
-
-
-
+    if capture_mode == "bytecode":
+        wrap_func = PIJitCaptureContext(jit_config)
+    elif capture_mode == "trace":
+        if function is not None:
+            return _jit_trace(function)
+        return _jit_trace
 
-    if fn is not None:
-        return wrap_func(fn)
+    if function is not None:
+        return wrap_func(function)
     return wrap_func
 
 
-def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
-    """
-    Create a callable MindSpore graph from a Python function.
-
-    This allows the MindSpore runtime to apply optimizations based on graph.
-
-    Note:
-        - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
-        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-          will not accept `**kwargs`.
-
-    Args:
-        fn (Function): The Python function that will be run as a graph. Default: ``None`` .
-        input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
-            will be supplied to this function. The shape and dtype of actual inputs of `fn` should
-            keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
-        hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
-            like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
-            will trigger recompilation. Default: ``None`` .
-        jit_config (JitConfig): Jit config for compile. Default: ``None`` .
-
-    Returns:
-        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
-        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
-        equal to the case when `fn` is not None.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import numpy as np
-        >>> from mindspore import Tensor
-        >>> from mindspore import ops
-        >>> from mindspore import ms_function
-        ...
-        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
-        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
-        ...
-        >>> # create a callable MindSpore graph by calling ms_function
-        >>> def tensor_add(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> tensor_add_graph = ms_function(fn=tensor_add)
-        >>> out = tensor_add_graph(x, y)
-        ...
-        >>> # create a callable MindSpore graph through decorator @ms_function
-        >>> @ms_function
-        ... def tensor_add_with_dec(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> out = tensor_add_with_dec(x, y)
-        ...
-        >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
-        >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
-        ...                               Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
-        ... def tensor_add_with_sig(x, y):
-        ...     z = x + y
-        ...     return z
-        ...
-        >>> out = tensor_add_with_sig(x, y)
-        ...
-        ... # Set hash_args as fn, otherwise cache of compiled `closure_fn` will not be reused.
-        ... # While fn differs during calling again, recompilation will be triggered.
-        >>> def func(x):
-        ...     return ops.exp(x)
-        ...
-        >>> def closure_fn(x, fn):
-        ...     @ms_function(hash_args=fn)
-        ...     def inner_fn(a):
-        ...         return fn(a)
-        ...     return inner_fn(x)
-        ...
-        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
-        >>> for i in range(10):
-        ...     closure_fn(inputs, func)
-    """
-
-    logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
-                   "Please use 'mindspore.jit' instead.")
-    return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
-
-
 def _core(fn=None, **flags):
     """
     A decorator that adds a flag to the function.
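Taken together, this hunk rewrites `jit` from the old `fn`/`mode`/`input_signature`/`hash_args`/`jit_config` signature to `function`/`capture_mode`/`jit_level`/`dynamic`/`fullgraph`/`backend` plus `**options` (with `infer_boost` and `exc_mode` read out of `options`). A minimal usage sketch; keyword names and accepted values are inferred from the validation code above, while the shapes and values are illustrative only:

```python
import numpy as np
from mindspore import Tensor, jit

# Accepted values, per the Validator calls in the new jit body:
#   capture_mode: "ast" | "bytecode" | "trace"
#   jit_level:    "O0" | "O1"
#   dynamic:      0 or 1
#   fullgraph:    bool (True maps to jit_syntax_level "STRICT")
#   backend:      "ms_backend" | "GE" (auto-selected when left empty)
@jit(capture_mode="ast", jit_level="O0", dynamic=0, fullgraph=False, backend="ms_backend")
def tensor_add(x, y):
    return x + y

x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
out = tensor_add(x, y)
```

Note the dispatch at the end of the function: `capture_mode="bytecode"` routes through `PIJitCaptureContext` and `capture_mode="trace"` through `_jit_trace`, so only the default `"ast"` mode goes through the `_JitExecutor` wrapper shown above.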
@@ -1203,69 +1286,6 @@ def _no_recursive(callable_obj):
     return callable_obj
 
 
-def ms_class(cls):
-    """
-    Class decorator for user-defined classes.
-
-    This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
-
-    Note:
-        `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.
-
-    Args:
-        cls (Class): User-defined class.
-
-    Returns:
-        Class.
-
-    Raises:
-        TypeError: If ms_class is used for non-class types or nn.Cell.
-        AttributeError: If the private attributes or magic methods of the class decorated with ms_class is called.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> from mindspore import ms_class
-        ...
-        >>> @ms_class
-        ... class UserDefinedNet:
-        ...     def __init__(self):
-        ...         self.value = 10
-        ...
-        ...     def func(self, x):
-        ...         return 2 * x
-        ...
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.net = UserDefinedNet()
-        ...
-        ...     def construct(self, x):
-        ...         out = self.net.value + self.net.func(x)
-        ...         return out
-        ...
-        >>> net = Net()
-        >>> out = net(5)
-        >>> print(out)
-        20
-    """
-
-    logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
-                   "Please use 'mindspore.jit_class' instead.")
-
-    # Check if cls is of type class.
-    if not inspect.isclass(cls):
-        raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
-    # Check if cls is nn.Cell.
-    if issubclass(cls, ms.nn.Cell):
-        raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
-    logger.info(f'Found ms_class: {cls}.')
-    setattr(cls, '__ms_class__', True)
-    return cls
-
-
 def jit_class(cls):
     """
     Class decorator for user-defined classes.
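The removed `ms_class` docstring already named its replacement, so migration is a rename. A sketch based on the example from the removed docstring, with `jit_class` (the surviving API) substituted in:

```python
import mindspore.nn as nn
from mindspore import jit_class  # was: from mindspore import ms_class

@jit_class  # was: @ms_class
class UserDefinedNet:
    def __init__(self):
        self.value = 10

    def func(self, x):
        return 2 * x

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.net = UserDefinedNet()

    def construct(self, x):
        # Attributes and methods of the jit_class-decorated object
        # are visible inside the compiled graph.
        return self.net.value + self.net.func(x)

net = Net()
out = net(5)  # 10 + 2 * 5 == 20
```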
@@ -1322,28 +1342,6 @@ def jit_class(cls):
     return cls
 
 
-def set_adapter_config(config):
-    """
-    Register configuration information for MSAdapter.
-
-    Args:
-        config (dict): Configuration information.
-    """
-    if not isinstance(config, dict):
-        raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
-    for key, value in config.items():
-        if key == "Tensor":
-            ms_adapter_registry.register_tensor(value)
-        elif key == "Parameter":
-            ms_adapter_registry.register_parameter(value)
-        elif key == "convert_object_map":
-            ms_adapter_registry.register_convert_map(value)
-        elif key == "convert_adapter_tensor_map":
-            ms_adapter_registry.register_convert_adapter_tensor_map(value)
-        else:
-            raise ValueError(f"Unsupported key in adapter config: {key}")
-
-
 def _function_forbid_reuse(func):
     if not inspect.isfunction(func):
         raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
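For readers tracking the MSAdapter hook: the removed `set_adapter_config` was a plain if/elif dispatch over four fixed keys. A table-driven equivalent, shown only as a hypothetical reconstruction (the helper no longer exists in 2.6.0rc1, and `ms_adapter_registry` is the module-internal registry object the removed body called):

```python
# Hypothetical reconstruction of the removed helper in table-driven form.
# The registry methods are the ones named in the removed code above.
_ADAPTER_REGISTER_FUNCS = {
    "Tensor": ms_adapter_registry.register_tensor,
    "Parameter": ms_adapter_registry.register_parameter,
    "convert_object_map": ms_adapter_registry.register_convert_map,
    "convert_adapter_tensor_map": ms_adapter_registry.register_convert_adapter_tensor_map,
}

def set_adapter_config(config):
    if not isinstance(config, dict):
        raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
    for key, value in config.items():
        if key not in _ADAPTER_REGISTER_FUNCS:
            raise ValueError(f"Unsupported key in adapter config: {key}")
        _ADAPTER_REGISTER_FUNCS[key](value)
```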
@@ -1535,7 +1533,24 @@ class _PyNativeExecutor:
         Return:
             None.
         """
-        return self._executor.grad(grad, obj, weights, grad_position, *args)
+        return self._executor.grad(grad, obj, weights, grad_position, False, *args)
+
+    def grad_aux(self, obj, grad, weights, grad_position, *args):
+        """
+        Run grad graph with aux
+
+        Args:
+            obj (Function/Cell): The function or cell instance.
+            grad (GradOperation): The gradoperation object.
+            weights (ParameterTuple): The weights of cell instance.
+            grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
+                If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
+            args (tuple): Function or cell input arguments.
+
+        Return:
+            None.
+        """
+        return self._executor.grad(grad, obj, weights, grad_position, True, *args)
 
     def clear_res(self):
         """
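The new boolean threaded into `self._executor.grad` (False for `grad`, True for `grad_aux`) looks like the backing for auxiliary outputs: extra function outputs that are returned as-is but excluded from differentiation. At the public API level that behavior is what `mindspore.value_and_grad(..., has_aux=True)` exposes. A minimal sketch of the convention, with illustrative values not taken from the source:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

def fn(x):
    loss = (x * x).sum()   # first output: differentiated
    logits = x + 1         # aux output: returned, not differentiated
    return loss, logits

x = Tensor(np.array([1.0, 2.0]).astype(np.float32))
grad_fn = ms.value_and_grad(fn, grad_position=0, has_aux=True)
(loss, logits), grads = grad_fn(x)  # grads == 2 * x, aux passes through untouched
```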
@@ -1671,6 +1686,15 @@ class _PyNativeExecutor:
|
|
|
1671
1686
|
"""
|
|
1672
1687
|
self._executor.set_is_run_recompute(status)
|
|
1673
1688
|
|
|
1689
|
+
def high_order(self):
|
|
1690
|
+
"""
|
|
1691
|
+
Is high order of current scene, this is a inner interface.
|
|
1692
|
+
|
|
1693
|
+
Return:
|
|
1694
|
+
Bool.
|
|
1695
|
+
"""
|
|
1696
|
+
return self._executor.high_order()
|
|
1697
|
+
|
|
1674
1698
|
def set_cell_use_dynamic_shape_process(self, flag):
|
|
1675
1699
|
"""
|
|
1676
1700
|
Set the dynamic shape flag of eval process.
|
|
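`high_order` is documented as an inner interface, but the scenario it reports on is presumably ordinary nested differentiation. A sketch of what "high order" means in PyNative mode, assuming the standard `mindspore.grad` transform (values illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

def f(x):
    return x ** 3

x = Tensor(np.array([2.0]).astype(np.float32))
first = ms.grad(f)(x)             # d/dx x^3 = 3x^2 -> 12.0
second = ms.grad(ms.grad(f))(x)   # d2/dx2 x^3 = 6x -> 12.0 (the "high order" case)
```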
@@ -1753,7 +1777,6 @@ class _CellGraphExecutor:
         # create needed graph by lazy mode
         self.is_init = False
         self.enable_tuple_broaden = False
-        self.obfuscate_config = None  # used for model's dynamic obfuscation
         self._graph_executor = GraphExecutor_.get_instance()
         self._graph_executor.set_py_exe_path(sys.executable)
         self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1845,6 +1868,7 @@ class _CellGraphExecutor:
         Str, the full phase of the cell.
         Bool, if the graph has been compiled before, return False, else return True.
         """
+        _init_auto_parallel_context(obj)
         obj.__parse_method__ = 'construct'
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
@@ -1877,6 +1901,7 @@ class _CellGraphExecutor:
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
+            _clear_auto_parallel_context(obj)
             return phase, False
 
         full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
@@ -1894,7 +1919,8 @@ class _CellGraphExecutor:
         else:
             jit_config_dict = JitConfig().jit_config_dict
         self._graph_executor.set_jit_config(jit_config_dict)
-
+        gc.collect()
+        result = self._graph_executor.compile(obj, args, kwargs, phase)
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
@@ -1915,6 +1941,7 @@ class _CellGraphExecutor:
             self._build_data_graph(obj, phase)
         elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
             _parameter_broadcast(obj)
+        _clear_auto_parallel_context(obj)
         return phase, True
 
     def _update_param_node_default_input(self, phase, replace):
@@ -1994,25 +2021,12 @@ class _CellGraphExecutor:
         """Clear the memory resource of a network."""
         self._graph_executor.del_net_res(obj, net_id)
 
-    def _get_branch_control_input(self):
-        if ('obf_ratio' not in self.obfuscate_config.keys()) or (
-                'obf_random_seed' not in self.obfuscate_config.keys()):
-            raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
-        obf_random_seed = self.obfuscate_config.get('obf_random_seed')
-        if obf_random_seed == 0:
-            branch_control_input = 0
-        else:
-            branch_control_input = _generate_branch_control_input(obf_random_seed)
-        return branch_control_input
-
     def _get_func_graph(self, obj, exec_id, use_prefix=False):
         """Get func graph from pipeline."""
         if use_prefix:
             exec_id = exec_id + '.' + obj.arguments_key
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
-        if self.obfuscate_config is not None:
-            raise ValueError('For get func graph, obfuscate_config is currently not supported now.')
         return self._graph_executor.get_func_graph(exec_id)
 
     def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
@@ -2021,11 +2035,6 @@ class _CellGraphExecutor:
             exec_id = exec_id + '.' + obj.arguments_key
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
-        if self.obfuscate_config is not None:
-            branch_control_input = self._get_branch_control_input()
-            return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
-                                                                       self.obfuscate_config['obf_ratio'],
-                                                                       branch_control_input)
         return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
 
     def get_optimize_graph_proto(self, obj):
@@ -2063,6 +2072,8 @@ def ms_memory_recycle():
     """
     if ms_compile_cache:
         _cell_graph_executor.del_net_res(None, ms_compile_cache)
+        if os.getenv('MS_DEV_JIT_PIPELINE') != '0':
+            JitExecutor_.get_instance().del_net_res(None, ms_compile_cache)
         ms_compile_cache.clear()
     for cell_cache in cells_compile_cache.values():
         if cell_cache:
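The extra branch means cache cleanup now also runs through the new jit pipeline executor unless the `MS_DEV_JIT_PIPELINE` environment variable is explicitly set to `'0'`. A sketch of opting out, assuming the variable is read at call time as the guard above suggests:

```python
import os
import mindspore as ms

# Opt out of the new jit-pipeline cleanup path before recycling compile caches.
os.environ["MS_DEV_JIT_PIPELINE"] = "0"
ms.ms_memory_recycle()
```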
@@ -2089,30 +2100,6 @@ def set_recursion_limit(recursion_limit=1000):
     GraphExecutor_.get_instance().set_max_call_depth(recursion_limit)
 
 
-def _generate_branch_control_input(obf_random_seed):
-    """Generate append network input for dynamic obfuscation in random seed mode."""
-    seed_max = 2 ** 32 - 1
-    int_max = 2 ** 31 - 1
-    np.random.seed(obf_random_seed % seed_max)
-    # generate a string as hash function inputs
-    word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
-    repo_len = len(word_repo)
-    sha_string = ''
-    string_len = 1024 * 1024
-    for _ in range(string_len):
-        rand_index = np.random.randint(0, repo_len)
-        sha_string += word_repo[rand_index]
-    # get hash result
-    sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest()  # len is 64
-    branch_control_input = 1
-    hex_base = 16
-    for item in sha_result:
-        if int(item, hex_base) > 0:
-            branch_control_input *= int(item, hex_base)
-    branch_control_input %= int_max
-    return branch_control_input
-
-
 def _bind_device_context():
     """Bind device context to current thread"""
     _bind_device_ctx()
@@ -2135,4 +2122,4 @@ def flops_collection(phase='train'):
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
 
-__all__ = ['
+__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']