mindspore-2.6.0-cp311-cp311-win_amd64.whl → mindspore-2.7.0-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +2 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +42 -11
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
- mindspore/_extends/parse/parser.py +64 -83
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +47 -14
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
- mindspore/amp.py +4 -22
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +43 -12
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +177 -52
- mindspore/common/_utils.py +9 -1
- mindspore/common/api.py +338 -208
- mindspore/common/dtype.py +108 -57
- mindspore/common/dump.py +11 -16
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/generator.py +2 -3
- mindspore/common/hook_handle.py +33 -5
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +27 -29
- mindspore/common/recompute.py +5 -7
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +84 -133
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +47 -38
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +38 -4
- mindspore/dataset/engine/datasets.py +350 -322
- mindspore/dataset/engine/datasets_user_defined.py +69 -23
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +10 -6
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +5 -4
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +1 -0
- mindspore/include/api/cell.h +65 -5
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +10 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +8 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +8 -3
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +61 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +4 -44
- mindspore/mint/distributed/__init__.py +5 -0
- mindspore/mint/distributed/distributed.py +425 -19
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +163 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +125 -101
- mindspore/mint/nn/layer/normalization.py +11 -25
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +488 -620
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +36 -36
- mindspore/nn/layer/basic.py +74 -77
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +86 -85
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +38 -40
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +2 -4
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/lamb.py +1 -3
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +2 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +73 -42
- mindspore/nn/wrap/grad_reducer.py +37 -52
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +7 -7
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +1 -1
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
- mindspore/ops/_vmap/vmap_array_ops.py +6 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +29 -10
- mindspore/ops/auto_generate/gen_extend_func.py +5 -55
- mindspore/ops/auto_generate/gen_ops_def.py +753 -273
- mindspore/ops/auto_generate/gen_ops_prim.py +1687 -958
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +9 -5
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +17 -100
- mindspore/ops/function/debug_func.py +8 -5
- mindspore/ops/function/grad/grad_func.py +5 -13
- mindspore/ops/function/math_func.py +65 -399
- mindspore/ops/function/nn_func.py +44 -61
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +31 -4
- mindspore/ops/functional.py +2 -3
- mindspore/ops/functional_overload.py +486 -18
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_custom_ops_utils.py +675 -8
- mindspore/ops/operations/_inner_ops.py +14 -18
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +4 -50
- mindspore/ops/operations/comm_ops.py +186 -41
- mindspore/ops/operations/custom_ops.py +244 -175
- mindspore/ops/operations/debug_ops.py +55 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +27 -28
- mindspore/ops/operations/math_ops.py +8 -9
- mindspore/ops/operations/nn_ops.py +6 -7
- mindspore/ops/primitive.py +9 -20
- mindspore/ops/tensor_method.py +52 -11
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +7 -2
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +9 -17
- mindspore/parallel/_cell_wrapper.py +106 -40
- mindspore/parallel/_parallel_serialization.py +4 -3
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +17 -12
- mindspore/parallel/_utils.py +5 -11
- mindspore/parallel/auto_parallel.py +33 -12
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +5 -1
- mindspore/parallel/cluster/process_entity/_api.py +88 -49
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +48 -7
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +7 -6
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
- mindspore/parallel/shard.py +9 -23
- mindspore/parallel/transform_safetensors.py +468 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +9 -0
- mindspore/profiler/common/profiler_context.py +50 -29
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +239 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +374 -338
- mindspore/profiler/envprofiler.py +42 -12
- mindspore/profiler/experimental_config.py +112 -7
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +30 -20
- mindspore/profiler/profiler.py +218 -154
- mindspore/profiler/profiler_action_controller.py +65 -77
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +8 -6
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +87 -45
- mindspore/runtime/memory.py +22 -30
- mindspore/runtime/thread_bind_core.py +299 -165
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +9 -5
- mindspore/train/amp.py +43 -23
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +4 -14
- mindspore/train/callback/_flops_collector.py +11 -7
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +72 -18
- mindspore/train/data_sink.py +15 -6
- mindspore/train/dataset_helper.py +14 -5
- mindspore/train/model.py +49 -47
- mindspore/train/serialization.py +168 -126
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +162 -78
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +14 -17
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/RECORD +400 -439
- mindspore/_deprecated/jit.py +0 -198
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/common/validator/__init__.py +0 -14
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
 """Providing interface methods."""
 from __future__ import absolute_import

+__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
+
 import gc
 import types
 import sys
@@ -42,23 +44,25 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx,
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
     _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
-from mindspore.common._utils import is_shape_unknown
+from mindspore.common._utils import is_shape_unknown, get_func
 from mindspore.common.mutable import mutable, _check_element_type
-from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args,
-
+from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
+    update_auto_dynamic_shape_phase
+from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
 from mindspore.common._pijit_context import PIJitCaptureContext
-from mindspore.common.parameter import Parameter
+from mindspore.common.parameter import Parameter
+from mindspore.common.hook_handle import _hook_version
 from mindspore.common.jit_context import jit_context
 from mindspore.common.jit_trace import _jit_trace
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context

-# Store
+# Store jit class compiled pipeline cache.
 ms_compile_cache = set()
 # Store cell compiled pipeline cache.
 cells_compile_cache = {}
@@ -72,6 +76,11 @@ ARG_SPECIFIED = "arg_specified_infos"
 TOTAL_ARG_LEN = "total_arg_length"


+def _real_phase(phase, obj):
+    real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+    return real_phase
+
+
 def _check_recompile_args(compile_args, kwargs):
     """Check recompile of graph"""
@@ -134,8 +143,6 @@ def _convert_python_data(data):
     """
     if isinstance(data, PythonTensor):
         return data
-    if isinstance(data, StubNode):
-        return ms.common._stub_tensor._convert_stub(data)
     if data.__class__ is tuple:
         # Handle namedtuple since its type is tuple.
         if hasattr(data, "_fields"):
@@ -278,13 +285,13 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
         module = importlib.util.module_from_spec(module_spec)
         if hasattr(module, '__file__'):
             dep_file_path = module.__file__
+            # Exclude the installed modules.
+            if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
+                logger.debug(f"dependent file path: {dep_file_path}")
+                compile_cache_dep_files.append(dep_file_path)
+                __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
         else:
             continue
-        # Exclude the installed modules.
-        if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
-            logger.debug(f"dependent file path: {dep_file_path}")
-            compile_cache_dep_files.append(dep_file_path)
-            __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)


 def _get_compile_cache_dep_files():
@@ -342,7 +349,7 @@ def _get_parameter_layout():
     return layout


-def _handle_arg(obj, arg, compile_arg):
+def _handle_arg(obj, arg, has_mutable_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     from mindspore._extends.parse import compile_config
     if isinstance(arg, PythonTensor):
@@ -352,7 +359,7 @@ def _handle_arg(obj, arg, compile_arg):
         return arg
     elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
         return arg
-    elif
+    elif has_mutable_arg:
         # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
         if isinstance(arg, list) and not arg:
             return None
@@ -366,7 +373,7 @@ def _handle_arg(obj, arg, compile_arg):
         return None


-def _handle_arg_predict(obj, arg, compile_arg):
+def _handle_arg_predict(obj, arg, has_mutable_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     if arg is None:
         return None
@@ -375,8 +382,7 @@ def _handle_arg_predict(obj, arg, compile_arg):
         return None

     if isinstance(arg, (list, tuple)):
-        if
-                getattr(compile_arg, "__ms_mutable__"):
+        if has_mutable_arg:
             # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
             if isinstance(arg, list) and not arg:
                 return None
@@ -388,35 +394,30 @@ def _handle_arg_predict(obj, arg, compile_arg):
         return arg


-def _get_args_for_run(obj, args, kwargs,
+def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
     """Get the actual input args and kwargs for runtime."""
     new_args = []
-
-
+    fn = _handle_arg_predict if is_predict else _handle_arg
+    for arg, has_mutable_arg in zip(args, has_mutable_args_list):
+        new_arg = fn(obj, arg, has_mutable_arg)
         if new_arg is not None:
             new_args.append(new_arg)

     for _, value in kwargs.items():
-        new_value =
+        new_value = fn(obj, value, None)
         if new_value is not None:
             new_args.append(new_value)

     return new_args


-def
-    """Get
+def _get_mutable_flags(compile_args):
+    """Get a list of booleans indicating whether each argument is marked as mutable"""
     new_args = []
-    for
-
-
-
-
-    for _, value in kwargs.items():
-        new_value = _handle_arg_predict(obj, value, None)
-        if new_value is not None:
-            new_args.append(new_value)
-
+    for compile_arg in compile_args:
+        has_mutable_arg = compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
+            getattr(compile_arg, "__ms_mutable__")
+        new_args.append(has_mutable_arg)
     return new_args
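The two functions above replace per-argument attribute probing with a single list of booleans computed up front from the compile-time arguments. The following standalone sketch is illustrative only, not MindSpore code: FakeArg and get_mutable_flags are hypothetical stand-ins for a compile-time argument and for the new _get_mutable_flags helper, reusing the same __ms_mutable__ marker that mindspore.mutable() sets.

    # Standalone sketch: derive one boolean per positional argument from the
    # __ms_mutable__ marker, then reuse the flags when filtering runtime args.
    class FakeArg:                        # hypothetical stand-in for a compile arg
        def __init__(self, value, mutable=False):
            self.value = value
            if mutable:
                setattr(self, "__ms_mutable__", True)

    def get_mutable_flags(compile_args):  # mirrors the shape of _get_mutable_flags
        return [arg is not None and getattr(arg, "__ms_mutable__", False)
                for arg in compile_args]

    args = [FakeArg([1, 2], mutable=True), FakeArg(3)]
    print(get_mutable_flags(args))        # [True, False]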
@@ -544,10 +545,12 @@ def _get_parameter_ids(args, kwargs):
         parameter_ids += str(id(value))
     return parameter_ids

+
 def _get_tensor_hook_key(tensor):
     """Get the hook key of Tensor/Parameter"""
     return ".".join(map(str, map(id, tensor.hooks())))

+
 def _get_hook_key(*args, **kwargs):
     """Get the hook key of Tensors/Parameters"""
     hook_key = ""
@@ -586,13 +589,16 @@ class _JitExecutor:
         The result of pipeline running in graph mode.
     """

-    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0
+    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0,
+                 cell_cache_key_extend=''):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
             raise RuntimeError('fn {} is not function or method'.format(fn))

         self.fn = fn
         self.input_signature = input_signature
+        self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
+        self.enable_jit_dynamic = self.dynamic_args_shapes is not None
         self.obj = None
         if obj and hasattr(obj, fn.__name__):
             self.obj = obj
@@ -606,6 +612,7 @@ class _JitExecutor:
         self._compile_args = None
         self._enable_auto_dynamic = dynamic == 1
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
+        self._cell_cache_key_extend = cell_cache_key_extend

     def _predict(self, *args, **kwargs):
         """Dedicated routine for predict."""
@@ -630,15 +637,18 @@ class _JitExecutor:
         else:  # get compiled args to generate run args by _generate_run_args
             compile_args = self._generate_compile_args(args_list)
             key_id = self._get_key_id()
-
-            compile_args
-
-
-                self._enable_auto_dynamic
-            )
+            if self.input_signature is None:
+                compile_args = get_auto_dynamic_shape_args(
+                    compile_args, key_id, self._enable_auto_dynamic
+                )
             self._compile_args = compile_args

         new_inputs = self._generate_run_args(args_list, kwargs)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+        self._graph_executor.set_jit_config(jit_config_dict)
         output = self._graph_executor(
             tuple(new_inputs),
             self.obj.phase_cache[self.obj.phase]
@@ -658,12 +668,9 @@ class _JitExecutor:
             args_list = args_list[1:]
         phase = ""
         try:
-
-
-
-            _pynative_executor.set_jit_compile_status(False, phase)
-            else:
-                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            _pynative_executor.set_jit_compile_status(True, phase)
+            phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            _pynative_executor.set_jit_compile_status(False, phase)
         except Exception as err:
             _pynative_executor.clear_res()
             raise err
@@ -672,31 +679,27 @@ class _JitExecutor:
             return None

         new_inputs = self._generate_run_args(args_list, kwargs)
-        if
-
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
         else:
-
-
-
-
-
-
+            jit_config_dict = JitConfig().jit_config_dict
+        self._graph_executor.set_jit_config(jit_config_dict)
+        output = _pynative_executor.grad_jit(*new_inputs)
+        if jit_context():
+            if is_stub_tensor(output):
+                output = output.stub_sync()
+            return jit_context().run_graph(phase, output, *tuple(new_inputs))
         return output

     def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
-        # Check whether hook function registered on Cell object.
-        if self.obj and hasattr(self.obj, "_hook_fn_registered"):
-            if self.obj._hook_fn_registered():
-                logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
-                               f"If you want to use hook function, please use context.set_context to set "
-                               f"pynative mode and remove 'jit' decorator.")
         # Chose dynamic shape tensors or actual input tensors as compile args.
         compile_args = self._generate_compile_args(args)
         key_id = self._get_key_id()
-
-
-
+        if self.input_signature is None:
+            compile_args = get_auto_dynamic_shape_args(
+                compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
+            )

         # Add mutable for compile_args for two scene:
         # 1) Origin args is mutable.
@@ -736,18 +739,23 @@ class _JitExecutor:

         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+        key = str(key)

         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-            key
+            key += '.' + parameter_ids
+
+        key += "." + _get_hook_key(*args, **kwargs)
+        key += "." + str(_hook_version())

-
+        phase = generate_name + '.' + key

-
+        if self.input_signature is None:
+            update_auto_dynamic_shape_phase(compile_args, key_id, phase)

-
+        phase = phase + self._cell_cache_key_extend

-        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase)
+        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
@@ -758,30 +766,23 @@ class _JitExecutor:
         # If enable compile cache, get the dependency files list and set to graph executor.
         self._set_compile_cache_dep_files()
         if self.jit_config_dict:
-            self.
+            jit_config_dict = self.jit_config_dict
         else:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)

         if self.obj is None:
             # Set an attribute to fn as an identifier.
-
-
-
-            setattr(self.fn, "__jit_function__", True)
-            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
-            if isinstance(self.fn, types.MethodType):
-                delattr(self.fn.__func__, "__jit_function__")
-            else:
-                delattr(self.fn, "__jit_function__")
+            setattr(get_func(self.fn), "__jit_function__", True)
+            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
+            delattr(get_func(self.fn), "__jit_function__")
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
-            is_compile = self._graph_executor.compile(
+            is_compile = self._graph_executor.compile(
+                self.obj, compile_args, kwargs, phase, jit_config_dict)

         if not is_compile:
             raise RuntimeError("Executor compile failed.")
-        set_parameter_hook_updated(False)
         ms_compile_cache.add(phase)
         if hasattr(self.obj, "phase"):
             self.obj.phase_cache[self.obj.phase] = phase
@@ -829,41 +830,70 @@ class _JitExecutor:
         if enable_compile_cache is True or enable_compile_cache == "1":
             self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

+    def _generate_compile_args_by_enable_dynamic(self, args_list):
+        """Generate compile args by enable_dynamic."""
+        compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
+        compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
+        if self.obj is not None:
+            _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+        else:
+            _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+        logger.info(f"dynamic shape compile_args: {compile_args}")
+        return compile_args
+
+    def _generate_compile_args_by_set_inputs(self, args_list):
+        """Generate compile args by set_inputs."""
+        compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
+        if len(compile_args) != len(args_list):
+            raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
+                             f"dynamic shape tensors: {len(compile_args)}.")
+        self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
+        Validator.check_symbolic_shape(compile_args, args_list)
+        return compile_args
+
+    def _generate_compile_args_by_input_signature(self, args_list):
+        """Generate compile args by input_signature."""
+        compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
+        dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
+        Validator.check_symbolic_shape(self.input_signature, args_list)
+        if dyn_shape:
+            # Checkout whether the `sens` has been added to args_list.
+            if len(compile_args) == len(args_list) - 1:
+                logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
+                               f"of input_signature args '{len(compile_args)}'. The last actual args may "
+                               f"be 'sens' and added it to compile args.")
+                compile_args.append(args_list[-1])
+            compile_args = tuple(compile_args)
+            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
+            if self.obj is not None:
+                _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+            else:
+                _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+        else:
+            if not verify_inputs_signature(compile_args, args_list):
+                raise ValueError("The input args is incompatible with the args in `input_signature`!")
+        return compile_args
+
+    def _check_set_inputs(self):
+        """Check if the `set_inputs()` of Cell object has been set."""
+        return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
+
     def _generate_compile_args(self, args_list):
         """Chose dynamic shape tensors or actual input tensors as compile args."""
-        # Case:
-
+        # Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
+        if self.enable_jit_dynamic and self._check_set_inputs():
+            raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
+        # Case: The `enable_dynamic` is provided.
+        if self.enable_jit_dynamic:
+            return self._generate_compile_args_by_enable_dynamic(args_list)
         # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
-        if self.
-
-            if len(compile_args) != len(args_list):
-                raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
-                                 f"dynamic shape tensors: {len(compile_args)}.")
-            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
-            Validator.check_symbolic_shape(compile_args, args_list)
-
+        if self._check_set_inputs():
+            return self._generate_compile_args_by_set_inputs(args_list)
         # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
         if self.input_signature is not None:
-
-
-
-            if dyn_shape:
-                # Checkout whether the `sens` has been added to args_list.
-                if len(compile_args) == len(args_list) - 1:
-                    logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
-                                   f"of input_signature args '{len(compile_args)}'. The last actual args may "
-                                   f"be 'sens' and added it to compile args.")
-                    compile_args.append(args_list[-1])
-            compile_args = tuple(compile_args)
-            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
-            if self.obj is not None:
-                _pynative_executor.set_dynamic_input(self.obj, *compile_args)
-            else:
-                _pynative_executor.set_dynamic_input(self.fn, *compile_args)
-        else:
-            if not verify_inputs_signature(compile_args, args_list):
-                raise ValueError("The input args is incompatible with the args in `input_signature`!")
-        return compile_args
+            return self._generate_compile_args_by_input_signature(args_list)
+        # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
+        return _pynative_executor.get_dynamic_input(args_list)

     def _generate_run_args(self, args_list, kwargs):
         """
@@ -876,7 +906,7 @@ class _JitExecutor:
         Returns:
             new_inputs, new input args, which are required for running.
         """
-        return _get_args_for_run(self, args_list, kwargs, self._compile_args)
+        return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)

     def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
         """Get graph proto from pipeline."""
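Read together with the helper methods added above, the refactored _generate_compile_args resolves compile arguments in a fixed priority order. The sketch below only restates that order; choose_compile_args and the methods it calls are illustrative names, not the real private API.

    # Priority order implied by the refactor above (illustrative names only).
    def choose_compile_args(executor, args_list):
        if executor.enable_jit_dynamic and executor.has_set_inputs():
            raise ValueError("enable_dynamic and Cell.set_inputs() cannot be combined")
        if executor.enable_jit_dynamic:            # 1) shapes declared via enable_dynamic
            return executor.from_enable_dynamic(args_list)
        if executor.has_set_inputs():              # 2) dynamic tensors from Cell.set_inputs()
            return executor.from_set_inputs(args_list)
        if executor.input_signature is not None:   # 3) decorator input_signature
            return executor.from_input_signature(args_list)
        return executor.auto_detected(args_list)   # 4) fall back to auto-detected dynamic shapes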
@@ -1037,6 +1067,67 @@ def _check_options(options, backend):
         _check_option_value(option, value)


+def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
+    """Return the wrapped function for ast mode jit."""
+    def wrap_func(func):
+        nonlocal hash_obj
+        if hasattr(func, "construct"):
+            if isinstance(func, ms.nn.Cell):
+                # Bound the cell object to get the self arg.
+                return types.MethodType(_jit_ast(
+                    hash_obj, dynamic, jit_config, func._jit_graph_name)(func.construct.__func__), func)
+            if isinstance(func, type) and issubclass(func, ms.nn.Cell):
+                func.construct = _jit_ast(
+                    hash_obj, dynamic, jit_config, '')(func.construct)
+                return func
+
+        if isinstance(func, types.MethodType):
+            return types.MethodType(_jit_ast(hash_obj, dynamic, jit_config, '')(func.__func__), func.__self__)
+
+        if not isinstance(func, types.FunctionType):
+            logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
+            return func
+
+        if hasattr(func, "__wrapped_by_jit__"):
+            logger.warning(f"The func {func} should be wrapped by jit only once.")
+
+        if hash_obj is None or not _is_inner_func(func):
+            hash_obj = int(time.time() * 1e9)
+
+        @wraps(func)
+        def staging_specialize(*args, **kwargs):
+            if os.getenv("MS_JIT") == '0':
+                return func(*args, **kwargs)
+
+            args, kwargs = _handle_func_args(func, *args, **kwargs)
+            process_obj = None
+            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
+                process_obj = args[0]
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
+
+            jit_graph_name = ''
+            if hasattr(staging_specialize, "__jit_graph_name__"):
+                jit_graph_name = staging_specialize.__jit_graph_name__
+            jit_executor = _JitExecutor(
+                func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
+            out = jit_executor(*args, **kwargs)
+            if isinstance(process_obj, ms.nn.Cell):
+                _clear_auto_parallel_context(process_obj)
+            return out
+
+        # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
+        # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
+        # original `func`.
+        staging_specialize.__signature__ = inspect.signature(func)
+        setattr(staging_specialize, "__wrapped_by_jit__", True)
+        setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
+        return staging_specialize
+
+    return wrap_func
+
+
 def jit(
     function: Optional[Callable] = None,
     *,
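The closing lines of staging_specialize set __signature__ so that introspection of a jit-wrapped function reports the original parameters rather than (*args, **kwargs). A self-contained, plain-Python illustration of that mechanism follows; my_jit is a hypothetical wrapper, not the MindSpore decorator.

    import inspect
    from functools import wraps

    def my_jit(func):                  # hypothetical wrapper using the same trick
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        # Without this line, inspect.getfullargspec(wrapper) would report (*args, **kwargs).
        wrapper.__signature__ = inspect.signature(func)
        return wrapper

    @my_jit
    def add(x, y=1):
        return x + y

    print(inspect.getfullargspec(add).args)  # ['x', 'y']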
@@ -1059,45 +1150,45 @@ def jit(
         and the decoration @jit(capture_mode="bytecode") is considered invalid.

     Args:
-        function (
+        function (Callable, optional): The Python function or Cell that will be run as a graph. Default: ``None``.

     Keyword Args:
         capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
-            should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
+            should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .

-            -
-
-
-
-              change and/or deletion.
-            - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
-              subject to change and/or deletion.
+            - ast: Parse Python ast to build graph.
+            - bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
+              that is subject to change and/or deletion.
+            - trace: Trace the execution of Python code to build graph. This is an experimental prototype
+              that is subject to change and/or deletion.

         jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
-            with
+            with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .

-            -
-            -
+            - O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
+            - O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
              level is experimental and is being improved.

         dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
             is as follows:

-            -
-            -
+            - 0: Do not perform dynamic shape compilation.
+            - 1: Enable dynamic shape compilation and automatically detect shape changes.

         fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
             be compatible with all Python syntax in the function as much as possible. If True, we require that the
             entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
-            not supported), then it will raise an exception. This currently only applies when capture_mode is ast
-            Default: ``False``.
+            not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
+            or ``bytecode``. Default: ``False``.
         backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
-            use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including
-            A2 training series products by default.
+            use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
+            Atlas A2 training series products by default.

-            -
-
-
+            - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
+              optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
+            - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
+              for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
+              and can be executed only on Ascend hardware.

         **options (dict): A dictionary of options to pass to the compilation backend.
@@ -1120,11 +1211,11 @@ def jit(
               `disable_format_transform` can be set to ``True`` to try to improve training performance.
               Default: ``False`` .
             - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
-              methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
+              methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .

-              -
+              - bfs: The default sorting method, breadth priority, good communication masking, relatively good
                performance.
-              -
+              - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
                of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
                other execution orders run out of memory (OOM).
@@ -1135,11 +1226,11 @@ def jit(
             - global (dict): Set global options.
             - session (dict): Set session options.

-            - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
+            - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
               the inference mode is disabled. The range is as follows:

-              -
-              -
+              - on: Enable inference mode, get better infer performance.
+              - off: Disable inference mode, use forward for inference. The performance is poor.

     Returns:
         Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -1158,29 +1249,84 @@ def jit(
         >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph by calling jit.
         >>> def tensor_add(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> tensor_add_graph = jit(function=tensor_add)
         >>> out = tensor_add_graph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph through decorator @jit.
         >>> @jit
         ... def tensor_add_with_dec(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> out = tensor_add_with_dec(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph and capture the entire function into the graph.
         >>> @jit(fullgraph=True)
         ... def tensor_add_fullgraph(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> out = tensor_add_fullgraph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph by trace mode.
+        >>> @jit(capture_mode="trace")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with ms_backend and jit_level="O1".
+        >>> @jit(backend="ms_backend", jit_level="O1")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with GE backend and some ge options on Ascend.
+        >>> @jit(backend="GE", ge_options={"global": {"ge.opSelectImplmode": "high_precision"}})
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
     """

     capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
@@ -1199,39 +1345,12 @@ def jit(
     jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
                            infer_boost=infer_boost, backend=backend, options=options_str)
 
-
-
-
-
-
-
-    def staging_specialize(*args, **kwargs):
-        if os.getenv("MS_JIT") == '0':
-            return func(*args, **kwargs)
-
-        args, kwargs = _handle_func_args(func, *args, **kwargs)
-        process_obj = None
-        if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
-            process_obj = args[0]
-        # Handle auto mixed precision strategy.
-        if not hasattr(func, "amp_strategy"):
-            if isinstance(func, types.MethodType):
-                setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
-            else:
-                setattr(func, "amp_strategy", get_curr_amp_strategy())
-
-        ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
-        out = ms_function_executor(*args, **kwargs)
-        return out
-
-    return staging_specialize
-
-    if capture_mode == "bytecode":
-        wrap_func = PIJitCaptureContext(jit_config)
-    elif capture_mode == "trace":
-        if function is not None:
-            return _jit_trace(function)
-        return _jit_trace
+    if capture_mode == "ast":
+        wrap_func = _jit_ast(hash_obj, dynamic, jit_config, '')
+    elif capture_mode == "bytecode":
+        wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
+    else:
+        wrap_func = _jit_trace()
 
     if function is not None:
         return wrap_func(function)
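The dispatch above is what the public `capture_mode` values resolve to: `ast` goes to `_jit_ast`, `bytecode` to `PIJitCaptureContext`, and anything else to the trace path. A usage sketch of the three modes, assuming `ast` is the default capture mode:

from mindspore import jit

@jit                                           # default capture: parse the function's source (ast)
def f_ast(x, y):
    return x + y

@jit(capture_mode="bytecode", fullgraph=True)  # bytecode capture; fullgraph is forwarded to PIJitCaptureContext
def f_bytecode(x, y):
    return x - y

@jit(capture_mode="trace")                     # build the graph by tracing one real execution
def f_trace(x, y):
    return x * y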
@@ -1547,7 +1666,7 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def check_run(self, grad, obj, weights, grad_hash_id, *args):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
         """
         Whether the forward graph need to construct.
 
@@ -1560,7 +1679,7 @@ class _PyNativeExecutor:
         Return:
             bool, specifies whether the forward graph needs to construct.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, **kwargs)
 
     def grad(self, obj, grad, weights, grad_position, *args):
         """
@@ -1802,6 +1921,19 @@ class _PyNativeExecutor:
         """
         return self._executor.constant_folding(*args)
 
+    def set_creation_type(self, tensor, creation_type):
+        """
+        Set tensor's view creation type
+
+        Args:
+            tensor (Tensor): input tensor.
+            creation_type (CreationType): The type of view tensor when it is created.
+
+        Return:
+            None.
+        """
+        return self._executor.set_creation_type(tensor, creation_type)
+
 
 class _CellGraphExecutor:
     """
@@ -1878,13 +2010,6 @@ class _CellGraphExecutor:
         else:
             _set_dataset_mode_config('normal')
 
-    @staticmethod
-    def _use_vm_mode():
-        enable_ge = context.get_context("enable_ge")
-        enable_debug_runtime = context.get_context("enable_debug_runtime")
-        exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
-        return not enable_ge or (enable_debug_runtime and exe_mode)
-
     def _build_data_graph(self, obj, phase):
         self._graph_executor.build_data_graph(obj.parameters_dict(), phase)
 
@@ -1916,7 +2041,12 @@ class _CellGraphExecutor:
         obj.__parse_method__ = 'construct'
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
-                'The class {}
+                'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+        inner_func = inspect.unwrap(obj.construct)
+        if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
+            raise ValueError(
+                "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
+            )
         key_id = str(id(obj)) + str(obj.create_time)
         args = get_auto_dynamic_shape_args(args, key_id)
 
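The new guard rejects `enable_dynamic` on a Cell's `construct` when compiling under graph mode. A sketch of the rejected pattern; only the attribute check and error message are visible in this diff, so the decorator itself is left as an assumption in the comments:

import mindspore as ms
from mindspore import nn

ms.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
    # Assumption: an enable_dynamic decorator sets the ENABLE_DYNAMIC attribute checked above.
    # Applying it to construct and then compiling this cell in GRAPH_MODE would raise:
    # ValueError: When using set_context(mode=GRAPH_MODE) together with nn.Cell,
    #             the 'enable_dynamic' cannot be set!
    def construct(self, x):
        return x + 1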
@@ -1927,20 +2057,25 @@ class _CellGraphExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
 
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
-
-
-        obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
+        key = str(key)
 
         # When exist parameter in the top graph inputs, need check if the parameter object has changed.
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-
+            key += '.' + parameter_ids
+
+        key += "." + _get_hook_key(*args, **kwargs)
+        key += "." + str(_hook_version())
+
+        obj.arguments_key = key
+
         raw_phase = phase
-
+
+        phase = _real_phase(phase, obj)
         obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
         obj.current_phase = phase
-        if phase in obj.compile_cache and self.has_compiled(phase)
+        if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
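The compile-cache key now folds in the parameter ids, the hook key, and a hook version, so a cell is recompiled when any of them changes. A purely illustrative sketch of that composition, not the executor's actual code:

def build_cache_key(base_key, parameter_ids, hook_key, hook_version):
    """Compose a compile-cache key the way the new compile path does:
    base key, then parameter identity, then hook key, then hook version."""
    key = str(base_key)
    if parameter_ids != "":
        key += '.' + parameter_ids
    key += '.' + hook_key
    key += '.' + str(hook_version)
    return key

# Two calls that differ only in hook_version map to different compiled graphs.
assert build_cache_key(1, "", "h0", 0) != build_cache_key(1, "", "h0", 1)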
@@ -1948,7 +2083,7 @@ class _CellGraphExecutor:
             _clear_auto_parallel_context(obj)
             return phase, False
 
-        full_function_name = obj.__class__.__name__ + '.' + str(obj.
+        full_function_name = obj.__class__.__name__ + '.' + str(obj.total_instance_count) + '.' + str(id(type(obj)))
         echo_function_name = obj.__class__.__name__
         _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
 
@@ -1958,17 +2093,14 @@ class _CellGraphExecutor:
         self._set_compile_cache_dep_files(phase)
 
         self._graph_executor.set_weights_values(obj.parameters_dict())
-        if jit_config_dict:
-            self._graph_executor.set_jit_config(jit_config_dict)
-        else:
+        if not jit_config_dict:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)
         gc.collect()
-        result = self._graph_executor.compile(
+        result = self._graph_executor.compile(
+            obj, args, kwargs, phase, jit_config_dict)
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
-        set_parameter_hook_updated(False)
         graph = self._graph_executor.get_func_graph(phase)
 
         if graph is None:
@@ -1993,15 +2125,15 @@ class _CellGraphExecutor:
         return self._graph_executor.updata_param_node_default_input(phase, new_param)
 
     def _get_shard_strategy(self, obj):
-        real_phase = obj.phase
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_strategy(real_phase)
 
     def _get_num_parallel_ops(self, obj):
-        real_phase = obj.phase
+        real_phase = _real_phase(obj.phase, obj)
        return self._graph_executor.get_num_parallel_ops(real_phase)
 
     def _get_allreduce_fusion(self, obj):
-        real_phase = obj.phase
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_allreduce_fusion(real_phase)
 
     def __call__(self, obj, *args, phase='predict'):
@@ -2053,10 +2185,10 @@ class _CellGraphExecutor:
             Tensor/Tuple, return execute result.
         """
         if phase == 'save':
-            exe_phase = phase
+            exe_phase = _real_phase(phase, obj)
             return self._graph_executor((), exe_phase)
 
-        phase_real = phase
+        phase_real = _real_phase(phase, obj)
         if self.has_compiled(phase_real):
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -2083,7 +2215,7 @@ class _CellGraphExecutor:
 
     def get_optimize_graph_proto(self, obj):
         """Return optimize graph binary proto."""
-        exec_id = obj.phase
+        exec_id = _real_phase(obj.phase, obj)
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
         graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
@@ -2165,5 +2297,3 @@ def flops_collection(phase='train'):
 
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
-
-__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']