mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -44,18 +44,20 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
  from mindspore._c_expression.amp import get_curr_amp_strategy
  from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
- _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx,
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, TensorPy as Tensor, dump_func_graph, _GraphFragment_
  from mindspore.parallel._ps_context import _is_role_sched
  from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
  _is_parallel_mode
  from mindspore import _checkparam as Validator
  from mindspore._checkparam import is_stub_tensor
- from mindspore.common._utils import is_shape_unknown
+ from mindspore.common._utils import is_shape_unknown, get_func
  from mindspore.common.mutable import mutable, _check_element_type
- from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args,
-
+ from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
+ update_auto_dynamic_shape_phase
+ from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
  from mindspore.common._pijit_context import PIJitCaptureContext
- from mindspore.common.parameter import Parameter
+ from mindspore.common.parameter import Parameter
+ from mindspore.common.hook_handle import _hook_version
  from mindspore.common.jit_context import jit_context
  from mindspore.common.jit_trace import _jit_trace
  from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context

@@ -74,6 +76,11 @@ ARG_SPECIFIED = "arg_specified_infos"
  TOTAL_ARG_LEN = "total_arg_length"


+ def _real_phase(phase, obj):
+ real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+ return real_phase
+
+
  def _check_recompile_args(compile_args, kwargs):
  """Check recompile of graph"""

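Note: the `_real_phase` helper added above centralizes the phase-string construction that the executor classes previously repeated inline; later hunks in this file replace bare `obj.phase` look-ups with `_real_phase(obj.phase, obj)`. A minimal illustration of the identifier it builds (all values below are invented; this snippet is not part of the diff):

class _FakeCell:
    create_time = 12
    arguments_key = "argkey0"

obj = _FakeCell()
print("train" + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key)
# prints something like "train.12.140543819.argkey0"
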
@@ -201,6 +208,11 @@ def _handle_func_args(func, *args, **kwargs):
  args = bound_arguments.args
  kwargs = bound_arguments.kwargs

+ return args, kwargs
+
+
+ def _check_func_args(func, *args):
+ """Check the *args inputs of the function"""
  positional_args = 0
  default_args = 0
  has_var = False

@@ -214,14 +226,13 @@ def _handle_func_args(func, *args, **kwargs):
  default_args += 1

  if has_var:
- return
+ return

  if len(args) < positional_args:
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
  if len(args) > positional_args + default_args:
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
  f"default argument, total {positional_args + default_args}, but got {len(args)}.")
- return args, kwargs


  sys_path = list(sys.path)

@@ -342,7 +353,7 @@ def _get_parameter_layout():
  return layout


- def _handle_arg(obj, arg, has_mutable_arg):
+ def _handle_arg(obj, arg, has_mutable_arg, is_predict):
  """Handle arg for runtime .If need handle the arg, return True"""
  from mindspore._extends.parse import compile_config
  if isinstance(arg, PythonTensor):

@@ -357,7 +368,7 @@ def _handle_arg(obj, arg, has_mutable_arg):
  if isinstance(arg, list) and not arg:
  return None
  return arg
- elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
+ elif not is_predict and (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
  isinstance(arg, (int, float)):
  return arg
  elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \

@@ -387,17 +398,16 @@ def _handle_arg_predict(obj, arg, has_mutable_arg):
  return arg


- def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
+ def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict=False):
  """Get the actual input args and kwargs for runtime."""
  new_args = []
- fn = _handle_arg_predict if is_predict else _handle_arg
  for arg, has_mutable_arg in zip(args, has_mutable_args_list):
- new_arg =
+ new_arg = _handle_arg(obj, arg, has_mutable_arg, is_predict)
  if new_arg is not None:
  new_args.append(new_arg)

  for _, value in kwargs.items():
- new_value =
+ new_value = _handle_arg(obj, value, None, is_predict)
  if new_value is not None:
  new_args.append(new_value)


@@ -538,10 +548,12 @@ def _get_parameter_ids(args, kwargs):
  parameter_ids += str(id(value))
  return parameter_ids

+
  def _get_tensor_hook_key(tensor):
  """Get the hook key of Tensor/Parameter"""
  return ".".join(map(str, map(id, tensor.hooks())))

+
  def _get_hook_key(*args, **kwargs):
  """Get the hook key of Tensors/Parameters"""
  hook_key = ""

@@ -588,6 +600,8 @@ class _JitExecutor:

  self.fn = fn
  self.input_signature = input_signature
+ self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
+ self.enable_jit_dynamic = self.dynamic_args_shapes is not None
  self.obj = None
  if obj and hasattr(obj, fn.__name__):
  self.obj = obj

@@ -598,7 +612,7 @@ class _JitExecutor:
  else:
  self._graph_executor = GraphExecutor_.get_instance()
  self._create_time = ms_create_time
- self.
+ self._mutable_flags = None
  self._enable_auto_dynamic = dynamic == 1
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
  self._cell_cache_key_extend = cell_cache_key_extend

@@ -623,18 +637,8 @@ class _JitExecutor:
  except Exception as err:
  _pynative_executor.clear_res()
  raise err
- else:  # get compiled args to generate run args by _generate_run_args
- compile_args = self._generate_compile_args(args_list)
- key_id = self._get_key_id()
- compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
- compile_args,
- key_id,
- self.input_signature,
- self._enable_auto_dynamic
- )
- self._compile_args = compile_args

- new_inputs = self._generate_run_args(args_list, kwargs)
+ new_inputs = self._generate_run_args(args_list, kwargs, is_predict=True)
  if self.jit_config_dict:
  jit_config_dict = self.jit_config_dict
  else:

@@ -647,11 +651,25 @@ class _JitExecutor:
  res = _convert_python_data(output)
  return True, res

+ def compile_frontend(self, *args, **kwargs):
+ """Only compile to the frontend graph."""
+ args_list = args
+ if self.obj is not None:
+ args_list = args_list[1:]
+ os.environ['MS_DEV_PRECOMPILE_ONLY'] = '1'
+ phase = ""
+ _pynative_executor.set_jit_compile_phase(phase)
+ phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+ _pynative_executor.set_jit_compile_phase(phase)
+ os.unsetenv('MS_DEV_PRECOMPILE_ONLY')
+ return self._graph_executor.get_func_graph(phase), self._mutable_flags, phase, self.enable_tuple_broaden
+
  @_wrap_func
  def __call__(self, *args, **kwargs):
  predict, res = self._predict(*args, **kwargs)
  if predict:
  return res
+ _check_func_args(self.fn, *args)
  if jit_context() and jit_context().is_nested():
  return jit_context().run_graph("", None, *())
  args_list = args

@@ -659,9 +677,9 @@ class _JitExecutor:
  args_list = args_list[1:]
  phase = ""
  try:
- _pynative_executor.
+ _pynative_executor.set_jit_compile_phase(phase)
  phase = self.compile(self.fn.__name__, *args_list, **kwargs)
- _pynative_executor.
+ _pynative_executor.set_jit_compile_phase(phase)
  except Exception as err:
  _pynative_executor.clear_res()
  raise err

@@ -684,24 +702,24 @@ class _JitExecutor:

  def compile(self, method_name, *args, **kwargs):
  """Returns pipeline for the given args."""
- # Check whether hook function registered on Cell object.
- if self.obj and hasattr(self.obj, "_hook_fn_registered"):
- if self.obj._hook_fn_registered():
- logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
- f"If you want to use hook function, please use context.set_context to set "
- f"pynative mode and remove 'jit' decorator.")
  # Chose dynamic shape tensors or actual input tensors as compile args.
+ self._graph_executor.set_real_args(args, kwargs)
  compile_args = self._generate_compile_args(args)
  key_id = self._get_key_id()
-
-
-
+ if self.input_signature is None:
+ compile_args = get_auto_dynamic_shape_args(
+ compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
+ )

  # Add mutable for compile_args for two scene:
  # 1) Origin args is mutable.
  # 2) Args contains sequence with gradient tensor.
  compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
-
+ mutable_flags = _get_mutable_flags(compile_args)
+ self._mutable_flags = mutable_flags
+ # Store the _mutable_flags in the cell obj for incremental inference.
+ if self.obj is not None:
+ self.obj._mutable_flags = mutable_flags
  generate_name, echo_function_name = self._get_generate_name()
  # The full Function name
  full_function_name = generate_name

@@ -735,20 +753,23 @@ class _JitExecutor:

  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
  key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+ key = str(key)

  parameter_ids = _get_parameter_ids(args, kwargs)
  if parameter_ids != "":
- key
+ key += '.' + parameter_ids

- key
+ key += "." + _get_hook_key(*args, **kwargs)
+ key += "." + str(_hook_version())

- phase = generate_name + '.' +
+ phase = generate_name + '.' + key

-
+ if self.input_signature is None:
+ update_auto_dynamic_shape_phase(compile_args, key_id, phase)

  phase = phase + self._cell_cache_key_extend

- if phase in ms_compile_cache and self._graph_executor.has_compiled(phase)
+ if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
  # generated in generate_arguments_key.
  self._graph_executor.clear_compile_arguments_resource()

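Note: the hunk above also changes how the compile-cache key is assembled in `_JitExecutor.compile`: the stringified arguments key is extended with the parameter ids, the tensor hook key, and the new global hook version from `mindspore.common.hook_handle._hook_version`, and the phase becomes `generate_name + '.' + key`. A minimal sketch of that composition with made-up inputs (the helper name `_compose_phase` is not part of the diff):

def _compose_phase(generate_name, arguments_key, parameter_ids, hook_key, hook_version):
    # Mirrors the 2.7.1 key construction shown above.
    key = str(arguments_key)
    if parameter_ids != "":
        key += '.' + parameter_ids
    key += '.' + hook_key
    key += '.' + str(hook_version)
    return generate_name + '.' + key

# _compose_phase("add_jit", "k123", "", "8800", 7) returns "add_jit.k123.8800.7"
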
@@ -765,16 +786,9 @@ class _JitExecutor:

  if self.obj is None:
  # Set an attribute to fn as an identifier.
-
-
-
- setattr(self.fn, "__jit_function__", True)
- is_compile = self._graph_executor.compile(
- self.fn, compile_args, kwargs, phase, jit_config_dict)
- if isinstance(self.fn, types.MethodType):
- delattr(self.fn.__func__, "__jit_function__")
- else:
- delattr(self.fn, "__jit_function__")
+ setattr(get_func(self.fn), "__jit_function__", True)
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
+ delattr(get_func(self.fn), "__jit_function__")
  else:
  if isinstance(self.obj, ms.nn.Cell):
  self._graph_executor.set_weights_values(self.obj.parameters_dict())

@@ -783,7 +797,6 @@ class _JitExecutor:

  if not is_compile:
  raise RuntimeError("Executor compile failed.")
- set_parameter_hook_updated(False)
  ms_compile_cache.add(phase)
  if hasattr(self.obj, "phase"):
  self.obj.phase_cache[self.obj.phase] = phase

@@ -831,43 +844,73 @@ class _JitExecutor:
  if enable_compile_cache is True or enable_compile_cache == "1":
  self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

+ def _generate_compile_args_by_enable_dynamic(self, args_list):
+ """Generate compile args by enable_dynamic."""
+ compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
+ compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
+ if self.obj is not None:
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+ else:
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+ logger.info(f"dynamic shape compile_args: {compile_args}")
+ Validator.check_symbolic_shape(compile_args, args_list)
+ return compile_args
+
+ def _generate_compile_args_by_set_inputs(self, args_list):
+ """Generate compile args by set_inputs."""
+ compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
+ if len(compile_args) != len(args_list):
+ raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
+ f"dynamic shape tensors: {len(compile_args)}.")
+ self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
+ Validator.check_symbolic_shape(compile_args, args_list)
+ return compile_args
+
+ def _generate_compile_args_by_input_signature(self, args_list):
+ """Generate compile args by input_signature."""
+ compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
+ dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
+ Validator.check_symbolic_shape(self.input_signature, args_list)
+ if dyn_shape:
+ # Checkout whether the `sens` has been added to args_list.
+ if len(compile_args) == len(args_list) - 1:
+ logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
+ f"of input_signature args '{len(compile_args)}'. The last actual args may "
+ f"be 'sens' and added it to compile args.")
+ compile_args.append(args_list[-1])
+ compile_args = tuple(compile_args)
+ self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
+ if self.obj is not None:
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+ else:
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+ else:
+ if not verify_inputs_signature(compile_args, args_list):
+ raise ValueError("The input args is incompatible with the args in `input_signature`!")
+ return compile_args
+
+ def _check_set_inputs(self):
+ """Check if the `set_inputs()` of Cell object has been set."""
+ return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
+
  def _generate_compile_args(self, args_list):
  """Chose dynamic shape tensors or actual input tensors as compile args."""
- # Case:
-
+ # Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
+ if self.enable_jit_dynamic and self._check_set_inputs():
+ raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
+ # Case: The `enable_dynamic` is provided.
+ if self.enable_jit_dynamic:
+ return self._generate_compile_args_by_enable_dynamic(args_list)
  # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
- if self.
-
- if len(compile_args) != len(args_list):
- raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
- f"dynamic shape tensors: {len(compile_args)}.")
- self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
- Validator.check_symbolic_shape(compile_args, args_list)
-
+ if self._check_set_inputs():
+ return self._generate_compile_args_by_set_inputs(args_list)
  # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
  if self.input_signature is not None:
-
-
-
- if dyn_shape:
- # Checkout whether the `sens` has been added to args_list.
- if len(compile_args) == len(args_list) - 1:
- logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
- f"of input_signature args '{len(compile_args)}'. The last actual args may "
- f"be 'sens' and added it to compile args.")
- compile_args.append(args_list[-1])
- compile_args = tuple(compile_args)
- self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
- if self.obj is not None:
- _pynative_executor.set_dynamic_input(self.obj, *compile_args)
- else:
- _pynative_executor.set_dynamic_input(self.fn, *compile_args)
- else:
- if not verify_inputs_signature(compile_args, args_list):
- raise ValueError("The input args is incompatible with the args in `input_signature`!")
- return compile_args
+ return self._generate_compile_args_by_input_signature(args_list)
+ # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
+ return _pynative_executor.get_dynamic_input(args_list)

- def _generate_run_args(self, args_list, kwargs):
+ def _generate_run_args(self, args_list, kwargs, is_predict=False):
  """
  Generate input args, which are required for running.

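Note: `_generate_compile_args` is now split into the three `_generate_compile_args_by_*` helpers added above and dispatches in the order enable_dynamic, then `set_inputs()`, then `input_signature`, falling back to the auto-dynamic input recorded by the PyNative executor; combining `enable_dynamic` with `set_inputs()` is rejected. A minimal usage sketch of the `set_inputs()` path, assuming the standard public dynamic-shape API (a placeholder Tensor whose `None` dimensions mark dynamic axes); it is illustrative only and not taken from the diff:

import mindspore as ms
from mindspore import nn, ops

class Net(nn.Cell):
    def construct(self, x):
        return ops.relu(x)

net = Net()
# Declare the first axis as dynamic; the compiler then uses these placeholders as compile args.
net.set_inputs(ms.Tensor(shape=[None, 3], dtype=ms.float32))
out = net(ms.Tensor([[1.0, -2.0, 3.0]], ms.float32))
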
@@ -878,7 +921,11 @@ class _JitExecutor:
  Returns:
  new_inputs, new input args, which are required for running.
  """
-
+ if self.obj is not None and hasattr(self.obj, '_mutable_flags'):
+ mutable_flags = self.obj._mutable_flags
+ else:
+ mutable_flags = self._mutable_flags
+ return _get_args_for_run(self, args_list, kwargs, mutable_flags, is_predict)

  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
  """Get graph proto from pipeline."""

@@ -950,7 +997,7 @@ def _check_option_backend(option, backend):
  'ge_options': ['GE'],
  'infer_boost': ['ms_backend'],
  }
- if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
+ if option in option_backend_cfgs and backend != '' and backend not in option_backend_cfgs[option]:
  logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
  f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")


@@ -1077,10 +1124,7 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
  process_obj = args[0]
  # Handle auto mixed precision strategy.
  if not hasattr(func, "amp_strategy"):
-
- setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
- else:
- setattr(func, "amp_strategy", get_curr_amp_strategy())
+ setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())

  jit_graph_name = ''
  if hasattr(staging_specialize, "__jit_graph_name__"):

@@ -1088,6 +1132,8 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
  jit_executor = _JitExecutor(
  func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
  out = jit_executor(*args, **kwargs)
+ if isinstance(process_obj, ms.nn.Cell):
+ _clear_auto_parallel_context(process_obj)
  return out

  # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set

@@ -1127,28 +1173,26 @@ def jit(

  Keyword Args:
  capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
- should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
+ should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .

- -
-
-
-
- change and/or deletion.
- - `trace <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#trace>`_ : Trace the execution of Python code to build graph. This is an experimental prototype that is
- subject to change and/or deletion.
+ - ast: Parse Python ast to build graph.
+ - bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
+ that is subject to change and/or deletion.
+ - trace: Trace the execution of Python code to build graph. This is an experimental prototype
+ that is subject to change and/or deletion.

  jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
- with ms_backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
+ with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .

- -
- -
+ - O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
+ - O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
  level is experimental and is being improved.

  dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
  is as follows:

- -
- -
+ - 0: Do not perform dynamic shape compilation.
+ - 1: Enable dynamic shape compilation and automatically detect shape changes.

  fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
  be compatible with all Python syntax in the function as much as possible. If True, we require that the

@@ -1156,12 +1200,16 @@ def jit(
  not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
  or ``bytecode``. Default: ``False``.
  backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
- use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including
- A2 training series products by default.
+ use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
+ Atlas A2 training series products by default.

- -
-
-
+ - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
+ optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
+ - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
+ for Ascend model compilation and execution. Note: This backend only supports GRAPH Mode in Ascend,
+ only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
+ dynamic shape scenes. In addition, this backend incurs additional compilation costs and is difficult to
+ debug and tune.

  **options (dict): A dictionary of options to pass to the compilation backend.


@@ -1184,11 +1232,11 @@ def jit(
  `disable_format_transform` can be set to ``True`` to try to improve training performance.
  Default: ``False`` .
  - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
- methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
+ methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .

- -
+ - bfs: The default sorting method, breadth priority, good communication masking, relatively good
  performance.
- -
+ - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
  of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
  other execution orders run out of memory (OOM).


@@ -1199,11 +1247,11 @@ def jit(
  - global (dict): Set global options.
  - session (dict): Set session options.

- - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
  the inference mode is disabled. The range is as follows:

- -
- -
+ - on: Enable inference mode, get better infer performance.
+ - off: Disable inference mode, use forward for inference. The performance is poor.

  Returns:
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is

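Note: the docstring hunks above mostly normalize the quoting of option values (``"ast"``, ``"O0"``, ``"bfs"``, ``"off"``) and rewrite the bullet lists, dropping the external documentation links. A minimal usage sketch combining the keyword arguments documented in this docstring (the values chosen are just one legal combination):

import mindspore as ms

@ms.jit(capture_mode="ast", jit_level="O0", dynamic=0, fullgraph=False, backend="ms_backend")
def add(x, y):
    return x + y

out = add(ms.Tensor(1.0), ms.Tensor(2.0))
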
@@ -1306,9 +1354,8 @@ def jit(
  jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
  dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
  fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
- if backend
- backend =
- backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
+ if backend != "":
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
  jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
  hash_obj = _get_hash_obj(options)
  _check_options(options, backend)

@@ -1323,7 +1370,7 @@ def jit(
  elif capture_mode == "bytecode":
  wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
  else:
- wrap_func = _jit_trace()
+ wrap_func = _jit_trace(jit_config)

  if function is not None:
  return wrap_func(function)

@@ -1530,6 +1577,20 @@ def _parameter_broadcast(obj):
  _build_broadcast_graph(broadcast_params_dict, broadcast_phase)


+ def _run_in_jit():
+ """In jit, this function always returns true. Otherwise, returns false."""
+ def _temp_func():
+ return 0
+
+ from mindspore.ops.primitive import constexpr
+
+ @constexpr(check=False)
+ def _check_func(func):
+ return func is None
+
+ return _check_func(_temp_func)
+
+
  class _no_grad(contextlib.ContextDecorator):
  """
  Context Manager to disable gradient calculation. When enter this context, we will disable calculate

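Note on the `_run_in_jit` helper added in the hunk above: a `@constexpr` function is evaluated at compile time under graph capture, and an argument whose value is not a compile-time constant reaches it as `None`; called eagerly, it receives the real Python object instead. So `_check_func(_temp_func)` is True inside a jit-compiled graph and False in PyNative execution. A small sketch of the same pattern with illustrative names (not part of the diff):

from mindspore.ops.primitive import constexpr

@constexpr(check=False)
def _arrived_as_constant(value):
    # Under graph compilation a non-constant input is passed in as None.
    return value is None

def in_graph_mode():
    # True when traced by @ms.jit / GRAPH_MODE, False when executed eagerly.
    return _arrived_as_constant(lambda: 0)
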
@@ -1799,17 +1860,16 @@ class _PyNativeExecutor:
  """
  return self._executor.requires_grad()

- def
+ def set_jit_compile_phase(self, phase):
  """
- Set jit
+ Set jit phase

  Args:
- status(bool): jit compile status
  phase (str): The phase of cell/function instance.
  Return:
  None.
  """
- self._executor.
+ self._executor.set_jit_compile_phase(phase)

  def set_is_run_recompute(self, status):
  """

@@ -1894,6 +1954,32 @@ class _PyNativeExecutor:
  """
  return self._executor.constant_folding(*args)

+ def set_creation_type(self, tensor, creation_type):
+ """
+ Set tensor's view creation type
+
+ Args:
+ tensor (Tensor): input tensor.
+ creation_type (CreationType): The type of view tensor when it is created.
+
+ Return:
+ None.
+ """
+ return self._executor.set_creation_type(tensor, creation_type)
+
+ def queue_backward_final_callback(self, callback):
+ """
+ add backward final callback
+
+ Args:
+ callback(Function): callback function.
+
+ Return:
+ None.
+ """
+ return self._executor.queue_backward_final_callback(callback)
+
+

  class _CellGraphExecutor:
  """

@@ -2002,6 +2088,11 @@ class _CellGraphExecutor:
  if not hasattr(obj, obj.__parse_method__):
  raise AttributeError(
  'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+ inner_func = inspect.unwrap(obj.construct)
+ if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
+ raise ValueError(
+ "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
+ )
  key_id = str(id(obj)) + str(obj.create_time)
  args = get_auto_dynamic_shape_args(args, key_id)


@@ -2012,20 +2103,27 @@ class _CellGraphExecutor:
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)

  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
-
-
- obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
+ key = str(key)

  # When exist parameter in the top graph inputs, need check if the parameter object has changed.
  parameter_ids = _get_parameter_ids(args, kwargs)
  if parameter_ids != "":
-
+ key += '.' + parameter_ids
+
+ key += "." + _get_hook_key(*args, **kwargs)
+ key += "." + str(_hook_version())
+
+ obj.arguments_key = key
+
  raw_phase = phase
-
+
+ phase = _real_phase(phase, obj)
  obj.phase_cache[raw_phase] = phase
  update_auto_dynamic_shape_phase(args, key_id, phase)
  obj.current_phase = phase
-
+ obj._add_attr("compile_phase", phase)
+ obj.compile_phase = phase
+ if phase in obj.compile_cache and self.has_compiled(phase):
  logger.debug("%r graph has existed.", phase)
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
  # generated in generate_arguments_key.

@@ -2051,7 +2149,6 @@ class _CellGraphExecutor:
  obj.compile_cache.add(phase)
  if not result:
  raise RuntimeError("Executor compile failed.")
- set_parameter_hook_updated(False)
  graph = self._graph_executor.get_func_graph(phase)

  if graph is None:

@@ -2075,16 +2172,20 @@ class _CellGraphExecutor:
  new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
  return self._graph_executor.updata_param_node_default_input(phase, new_param)

+ def set_real_args(self, args, kwargs):
+ """Set real arguments to graph executor."""
+ self._graph_executor.set_real_args(args, kwargs)
+
  def _get_shard_strategy(self, obj):
- real_phase = obj.phase
+ real_phase = _real_phase(obj.phase, obj)
  return self._graph_executor.get_strategy(real_phase)

  def _get_num_parallel_ops(self, obj):
- real_phase = obj.phase
+ real_phase = _real_phase(obj.phase, obj)
  return self._graph_executor.get_num_parallel_ops(real_phase)

  def _get_allreduce_fusion(self, obj):
- real_phase = obj.phase
+ real_phase = _real_phase(obj.phase, obj)
  return self._graph_executor.get_allreduce_fusion(real_phase)

  def __call__(self, obj, *args, phase='predict'):

@@ -2136,10 +2237,10 @@ class _CellGraphExecutor:
  Tensor/Tuple, return execute result.
  """
  if phase == 'save':
- exe_phase = phase
+ exe_phase = _real_phase(phase, obj)
  return self._graph_executor((), exe_phase)

- phase_real = phase
+ phase_real = _real_phase(phase, obj)
  if self.has_compiled(phase_real):
  return self._exec_pip(obj, *args, phase=phase_real)
  raise KeyError('{} graph is not exist.'.format(phase_real))

@@ -2164,9 +2265,22 @@ class _CellGraphExecutor:
  return None
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)

+ def _get_onnx_func_graph_proto(self, obj, exec_id, use_prefix=False, input_names=None, output_names=None,
+ opset_version=11, export_params=True, keep_initializers_as_inputs=False,
+ dynamic_axes=None, extra_save_params=False, save_file_dir=None):
+ """Get graph proto from pipeline."""
+ if use_prefix:
+ exec_id = exec_id + '.' + obj.arguments_key
+ if self._graph_executor.has_compiled(exec_id) is False:
+ return None
+
+ return self._graph_executor.get_onnx_func_graph_proto(exec_id, input_names, output_names, opset_version,
+ export_params, keep_initializers_as_inputs, dynamic_axes,
+ extra_save_params, save_file_dir)
+
  def get_optimize_graph_proto(self, obj):
  """Return optimize graph binary proto."""
- exec_id = obj.phase
+ exec_id = _real_phase(obj.phase, obj)
  if self._graph_executor.has_compiled(exec_id) is False:
  return None
  graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)

@@ -2246,5 +2360,190 @@ def flops_collection(phase='train'):
  return _cell_graph_executor.flops_collection(phase)


+ class _ScriptGraph:
+ """Store the graph compiled by the frontend compiler."""
+ def __init__(self, func_graph, func, origin_cell, mutable_flags, phase, enable_tuple_broaden):
+ self.func_graph = func_graph
+ self.func = func
+ self.origin_cell = origin_cell
+ self.mutable_flags = mutable_flags
+ self.phase = phase
+ self.enable_tuple_broaden = enable_tuple_broaden
+
+ def print(self):
+ """Print the MindIR of the frontend graph."""
+ graph_str = dump_func_graph(self.func_graph)
+ print(graph_str, flush=True)
+
+
+ def _frontend_compile_ast(dynamic, jit_config, jit_graph_name=''):
+ """Return the wrapped function for ast mode jit."""
+ def wrap_func(func):
+ if hasattr(func, "construct") and isinstance(func, ms.nn.Cell):
+ # Bound the cell object to get the self arg.
+ return types.MethodType(_frontend_compile_ast(dynamic, jit_config,
+ func._jit_graph_name)(func.construct.__func__), func)
+
+ if isinstance(func, types.MethodType):
+ return types.MethodType(_frontend_compile_ast(dynamic, jit_config)(func.__func__), func.__self__)
+
+ if not isinstance(func, types.FunctionType):
+ logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
+ return func
+
+ hash_obj = int(time.time() * 1e9)
+
+ @wraps(func)
+ def staging_specialize(*args, **kwargs):
+ if os.getenv("MS_JIT") == '0':
+ return func(*args, **kwargs)
+
+ args, kwargs = _handle_func_args(func, *args, **kwargs)
+ process_obj = None
+ if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
+ process_obj = args[0]
+ # Handle auto mixed precision strategy.
+ if not hasattr(func, "amp_strategy"):
+ setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
+
+ jit_graph_name = ''
+ if hasattr(staging_specialize, "__jit_graph_name__"):
+ jit_graph_name = staging_specialize.__jit_graph_name__
+ jit_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
+ func_graph, mutable_flags, phase, enable_tuple_broaden = jit_executor.compile_frontend(*args, **kwargs)
+ return _ScriptGraph(func_graph, func, process_obj, mutable_flags, phase, enable_tuple_broaden)
+
+ # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
+ # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
+ # original `func`.
+ staging_specialize.__signature__ = inspect.signature(func)
+ setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
+ return staging_specialize
+
+ return wrap_func
+
+
+ def _frontend_compile(function: Callable,
+ *,
+ dynamic: int = 0,
+ fullgraph: bool = False):
+ """
+ Create a frontend MindSpore graph from a Python function by the ast capture mode.
|
|
2432
|
+
|
|
2433
|
+
Args:
|
|
2434
|
+
function (Callable, optional): The Python function or Cell instance that will be compiled as a frontend graph.
|
|
2435
|
+
Default: ``None``.
|
|
2436
|
+
|
|
2437
|
+
Keyword Args:
|
|
2438
|
+
dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
|
|
2439
|
+
is as follows:
|
|
2440
|
+
|
|
2441
|
+
- `0`: Do not perform dynamic shape compilation.
|
|
2442
|
+
- `1`: Enable dynamic shape compilation and automatically detect shape changes.
|
|
2443
|
+
|
|
2444
|
+
fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
|
|
2445
|
+
be compatible with all Python syntax in the function as much as possible. If True, we require that the
|
|
2446
|
+
entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
|
|
2447
|
+
not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
|
|
2448
|
+
or ``bytecode``. Default: ``False``.
|
|
2449
|
+
|
|
2450
|
+
Returns:
|
|
2451
|
+
a :class:`_ScriptGraph` object.
|
|
2452
|
+
|
|
2453
|
+
Supported Platforms:
|
|
2454
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2455
|
+
|
|
2456
|
+
Examples:
|
|
2457
|
+
>>> import numpy as np
|
|
2458
|
+
>>> from mindspore import Tensor
|
|
2459
|
+
>>> from mindspore import ops
|
|
2460
|
+
>>> from mindspore.common.api import _frontend_compile
|
|
2461
|
+
...
|
|
2462
|
+
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
2463
|
+
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
2464
|
+
...
|
|
2465
|
+
>>> def tensor_add(x, y):
|
|
2466
|
+
... z = x + y
|
|
2467
|
+
... return z
|
|
2468
|
+
...
|
|
2469
|
+
>>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
|
|
2470
|
+
>>> tensor_add_graph.print()
|
|
2471
|
+
...
|
|
2472
|
+
"""
|
|
2473
|
+
|
|
2474
|
+
dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
|
|
2475
|
+
fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
|
|
2476
|
+
jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
|
|
2477
|
+
jit_config = JitConfig(jit_syntax_level=jit_syntax_level)
|
|
2478
|
+
return _frontend_compile_ast(dynamic, jit_config)(function)
|
|
2479
|
+
|
|
2480
|
+
|
|
2481
|
+
class _GraphFragment(_GraphFragment_):
|
|
2482
|
+
"""
|
|
2483
|
+
Represents the output by backend graph split.
|
|
2484
|
+
"""
|
|
2485
|
+
def __init__(self, frag):
|
|
2486
|
+
if frag is None or not isinstance(frag, _GraphFragment_):
|
|
2487
|
+
raise TypeError(f"Expect input `frag` to be a _GraphFragment_, but got {type(frag)}")
|
|
2488
|
+
_GraphFragment_.__init__(self, frag)
|
|
2489
|
+
|
|
2490
|
+
def __call__(self, *args):
|
|
2491
|
+
return super().__call__(args)
|
|
2492
|
+
|
|
2493
|
+
def __repr__(self):
|
|
2494
|
+
return self.__str__()
|
|
2495
|
+
|
|
2496
|
+
def id(self):
|
|
2497
|
+
return self.id_()
|
|
2498
|
+
|
|
2499
|
+
def is_graph(self):
|
|
2500
|
+
return self.is_graph_()
|
|
2501
|
+
|
|
2502
|
+
def py_key(self):
|
|
2503
|
+
return self.py_key_()
|
|
2504
|
+
|
|
2505
|
+
def args_list(self):
|
|
2506
|
+
return self.args_list_()
|
|
2507
|
+
|
|
2508
|
+
|
|
2509
|
+
def _graph_split(script_graph):
|
|
2510
|
+
"""
|
|
2511
|
+
Split the script_graph into several fragments according to the nodes with the split op attribute.
|
|
2512
|
+
|
|
2513
|
+
Args:
|
|
2514
|
+
a :class:`_ScriptGraph` object.
|
|
2515
|
+
|
|
2516
|
+
Returns:
|
|
2517
|
+
several :class:`_GraphFragment` object.
|
|
2518
|
+
|
|
2519
|
+
Supported Platforms:
|
|
2520
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2521
|
+
|
|
2522
|
+
Examples:
|
|
2523
|
+
>>> import numpy as np
|
|
2524
|
+
>>> from mindspore import Tensor
|
|
2525
|
+
>>> from mindspore import ops
|
|
2526
|
+
>>> from mindspore.common.api import _frontend_compile, _graph_split
|
|
2527
|
+
...
|
|
2528
|
+
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
2529
|
+
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
2530
|
+
>>> add = ops.Add().add_prim_attr("split_op", True).add_prim_attr("func_id", "add_func")
|
|
2531
|
+
...
|
|
2532
|
+
>>> def tensor_add(x, y):
|
|
2533
|
+
... z1 = x + y
|
|
2534
|
+
... z2 = add(z1, x)
|
|
2535
|
+
... return z2
|
|
2536
|
+
...
|
|
2537
|
+
>>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
|
|
2538
|
+
>>> frags = _graph_split(tensor_add_graph)
|
|
2539
|
+
>>> print(frags)
|
|
2540
|
+
...
|
|
2541
|
+
"""
|
|
2542
|
+
outputs = JitExecutor_.get_instance().split_graph(script_graph.func_graph)
|
|
2543
|
+
fragments = []
|
|
2544
|
+
for arg in outputs:
|
|
2545
|
+
fragments.append(_GraphFragment(arg))
|
|
2546
|
+
return fragments
|
|
2547
|
+
|
|
2249
2548
|
_cell_graph_executor = _CellGraphExecutor()
|
|
2250
2549
|
_pynative_executor = _PyNativeExecutor()
|