mindspore 2.6.0rc1-cp311-cp311-win_amd64.whl → 2.7.0-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +2 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +42 -11
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
- mindspore/_extends/parse/parser.py +65 -84
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +58 -14
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
- mindspore/amp.py +4 -22
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +43 -12
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +178 -53
- mindspore/common/_utils.py +9 -1
- mindspore/common/api.py +377 -203
- mindspore/common/dtype.py +108 -57
- mindspore/common/dump.py +11 -16
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +33 -5
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +27 -29
- mindspore/common/recompute.py +5 -7
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +117 -131
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +67 -55
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +38 -4
- mindspore/dataset/engine/datasets.py +350 -322
- mindspore/dataset/engine/datasets_user_defined.py +70 -24
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +10 -6
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +65 -5
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +10 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +8 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +8 -3
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +61 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +5 -0
- mindspore/mint/distributed/distributed.py +429 -23
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +163 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +140 -104
- mindspore/mint/nn/layer/normalization.py +11 -25
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +491 -623
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +36 -36
- mindspore/nn/layer/basic.py +74 -77
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +38 -40
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +4 -6
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/lamb.py +1 -3
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +2 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +73 -42
- mindspore/nn/wrap/grad_reducer.py +37 -52
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +7 -7
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
- mindspore/ops/auto_generate/gen_extend_func.py +27 -145
- mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
- mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +9 -5
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +19 -102
- mindspore/ops/function/debug_func.py +8 -5
- mindspore/ops/function/grad/grad_func.py +5 -13
- mindspore/ops/function/math_func.py +77 -572
- mindspore/ops/function/nn_func.py +46 -94
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +4 -4
- mindspore/ops/functional_overload.py +594 -18
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +14 -18
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +5 -51
- mindspore/ops/operations/comm_ops.py +186 -41
- mindspore/ops/operations/custom_ops.py +303 -177
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +27 -28
- mindspore/ops/operations/math_ops.py +8 -9
- mindspore/ops/operations/nn_ops.py +8 -40
- mindspore/ops/primitive.py +9 -20
- mindspore/ops/tensor_method.py +63 -15
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +16 -23
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +4 -3
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +17 -12
- mindspore/parallel/_utils.py +5 -11
- mindspore/parallel/auto_parallel.py +35 -14
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +13 -7
- mindspore/parallel/cluster/process_entity/_api.py +88 -49
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +48 -7
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +12 -12
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
- mindspore/parallel/shard.py +10 -25
- mindspore/parallel/transform_safetensors.py +469 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +9 -0
- mindspore/profiler/common/profiler_context.py +50 -29
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +239 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +374 -338
- mindspore/profiler/envprofiler.py +42 -12
- mindspore/profiler/experimental_config.py +112 -7
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +30 -20
- mindspore/profiler/profiler.py +218 -154
- mindspore/profiler/profiler_action_controller.py +65 -77
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +8 -6
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +87 -45
- mindspore/runtime/memory.py +31 -32
- mindspore/runtime/thread_bind_core.py +299 -165
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +17 -7
- mindspore/train/amp.py +43 -23
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +4 -14
- mindspore/train/callback/_flops_collector.py +11 -7
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +98 -21
- mindspore/train/data_sink.py +15 -6
- mindspore/train/dataset_helper.py +14 -5
- mindspore/train/model.py +133 -69
- mindspore/train/serialization.py +168 -126
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +14 -17
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
- mindspore/_deprecated/jit.py +0 -198
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/common/validator/__init__.py +0 -14
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
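Note: one recurring internal change in the hunks below is that `_get_args_for_run` no longer inspects the compile-time arguments itself; a list of per-argument mutable flags is computed once by the new `_get_mutable_flags` helper and then zipped with the runtime arguments. A minimal standalone sketch of that idea (illustration only, not MindSpore code; `FakeMutable` is a hypothetical stand-in for an object produced by `mindspore.mutable()`):

    # Standalone sketch: derive mutable flags once from compile-time args,
    # then pair them with the runtime args instead of re-inspecting per call.
    class FakeMutable:
        """Hypothetical stand-in for an argument produced by mindspore.mutable()."""
        __ms_mutable__ = True

    def get_mutable_flags(compile_args):
        # Mirrors the new _get_mutable_flags: one boolean per compile-time argument.
        return [arg is not None and bool(getattr(arg, "__ms_mutable__", False))
                for arg in compile_args]

    def pair_runtime_args(args, mutable_flags):
        # Mirrors the zip in the new _get_args_for_run signature.
        return list(zip(args, mutable_flags))

    print(get_mutable_flags((FakeMutable(), 3)))        # [True, False]
    print(pair_runtime_args(("x", "y"), [True, False]))  # [('x', True), ('y', False)]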
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
2
|
#
|
|
3
|
-
# Copyright 2020-
|
|
3
|
+
# Copyright 2020-2025 Huawei Technologies Co., Ltd
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
6
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
"""Providing interface methods."""
|
|
18
18
|
from __future__ import absolute_import
|
|
19
19
|
|
|
20
|
+
__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
|
|
21
|
+
|
|
20
22
|
import gc
|
|
21
23
|
import types
|
|
22
24
|
import sys
|
|
@@ -42,23 +44,25 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
|
|
|
42
44
|
from mindspore._c_expression.amp import get_curr_amp_strategy
|
|
43
45
|
from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
|
|
44
46
|
PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
|
|
45
|
-
_run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx,
|
|
47
|
+
_run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
|
|
46
48
|
from mindspore.parallel._ps_context import _is_role_sched
|
|
47
49
|
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
|
|
48
50
|
_is_parallel_mode
|
|
49
51
|
from mindspore import _checkparam as Validator
|
|
50
52
|
from mindspore._checkparam import is_stub_tensor
|
|
51
|
-
from mindspore.common._utils import is_shape_unknown
|
|
53
|
+
from mindspore.common._utils import is_shape_unknown, get_func
|
|
52
54
|
from mindspore.common.mutable import mutable, _check_element_type
|
|
53
|
-
from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args,
|
|
54
|
-
|
|
55
|
+
from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
|
|
56
|
+
update_auto_dynamic_shape_phase
|
|
57
|
+
from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
|
|
55
58
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
56
|
-
from mindspore.common.parameter import Parameter
|
|
59
|
+
from mindspore.common.parameter import Parameter
|
|
60
|
+
from mindspore.common.hook_handle import _hook_version
|
|
57
61
|
from mindspore.common.jit_context import jit_context
|
|
58
62
|
from mindspore.common.jit_trace import _jit_trace
|
|
59
63
|
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
60
64
|
|
|
61
|
-
# Store
|
|
65
|
+
# Store jit class compiled pipeline cache.
|
|
62
66
|
ms_compile_cache = set()
|
|
63
67
|
# Store cell compiled pipeline cache.
|
|
64
68
|
cells_compile_cache = {}
|
|
@@ -72,6 +76,11 @@ ARG_SPECIFIED = "arg_specified_infos"
|
|
|
72
76
|
TOTAL_ARG_LEN = "total_arg_length"
|
|
73
77
|
|
|
74
78
|
|
|
79
|
+
def _real_phase(phase, obj):
|
|
80
|
+
real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
81
|
+
return real_phase
|
|
82
|
+
|
|
83
|
+
|
|
75
84
|
def _check_recompile_args(compile_args, kwargs):
|
|
76
85
|
"""Check recompile of graph"""
|
|
77
86
|
|
|
@@ -134,8 +143,6 @@ def _convert_python_data(data):
|
|
|
134
143
|
"""
|
|
135
144
|
if isinstance(data, PythonTensor):
|
|
136
145
|
return data
|
|
137
|
-
if isinstance(data, StubNode):
|
|
138
|
-
return ms.common._stub_tensor._convert_stub(data)
|
|
139
146
|
if data.__class__ is tuple:
|
|
140
147
|
# Handle namedtuple since its type is tuple.
|
|
141
148
|
if hasattr(data, "_fields"):
|
|
@@ -278,13 +285,13 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
|
|
|
278
285
|
module = importlib.util.module_from_spec(module_spec)
|
|
279
286
|
if hasattr(module, '__file__'):
|
|
280
287
|
dep_file_path = module.__file__
|
|
288
|
+
# Exclude the installed modules.
|
|
289
|
+
if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
|
|
290
|
+
logger.debug(f"dependent file path: {dep_file_path}")
|
|
291
|
+
compile_cache_dep_files.append(dep_file_path)
|
|
292
|
+
__get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
|
|
281
293
|
else:
|
|
282
294
|
continue
|
|
283
|
-
# Exclude the installed modules.
|
|
284
|
-
if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
|
|
285
|
-
logger.debug(f"dependent file path: {dep_file_path}")
|
|
286
|
-
compile_cache_dep_files.append(dep_file_path)
|
|
287
|
-
__get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
|
|
288
295
|
|
|
289
296
|
|
|
290
297
|
def _get_compile_cache_dep_files():
|
|
@@ -342,7 +349,7 @@ def _get_parameter_layout():
|
|
|
342
349
|
return layout
|
|
343
350
|
|
|
344
351
|
|
|
345
|
-
def _handle_arg(obj, arg,
|
|
352
|
+
def _handle_arg(obj, arg, has_mutable_arg):
|
|
346
353
|
"""Handle arg for runtime .If need handle the arg, return True"""
|
|
347
354
|
from mindspore._extends.parse import compile_config
|
|
348
355
|
if isinstance(arg, PythonTensor):
|
|
@@ -352,7 +359,7 @@ def _handle_arg(obj, arg, compile_arg):
|
|
|
352
359
|
return arg
|
|
353
360
|
elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
|
|
354
361
|
return arg
|
|
355
|
-
elif
|
|
362
|
+
elif has_mutable_arg:
|
|
356
363
|
# mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
|
|
357
364
|
if isinstance(arg, list) and not arg:
|
|
358
365
|
return None
|
|
@@ -366,7 +373,7 @@ def _handle_arg(obj, arg, compile_arg):
|
|
|
366
373
|
return None
|
|
367
374
|
|
|
368
375
|
|
|
369
|
-
def _handle_arg_predict(obj, arg,
|
|
376
|
+
def _handle_arg_predict(obj, arg, has_mutable_arg):
|
|
370
377
|
"""Handle arg for runtime .If need handle the arg, return True"""
|
|
371
378
|
if arg is None:
|
|
372
379
|
return None
|
|
@@ -375,8 +382,7 @@ def _handle_arg_predict(obj, arg, compile_arg):
|
|
|
375
382
|
return None
|
|
376
383
|
|
|
377
384
|
if isinstance(arg, (list, tuple)):
|
|
378
|
-
if
|
|
379
|
-
getattr(compile_arg, "__ms_mutable__"):
|
|
385
|
+
if has_mutable_arg:
|
|
380
386
|
# mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
|
|
381
387
|
if isinstance(arg, list) and not arg:
|
|
382
388
|
return None
|
|
@@ -388,35 +394,30 @@ def _handle_arg_predict(obj, arg, compile_arg):
|
|
|
388
394
|
return arg
|
|
389
395
|
|
|
390
396
|
|
|
391
|
-
def _get_args_for_run(obj, args, kwargs,
|
|
397
|
+
def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
|
|
392
398
|
"""Get the actual input args and kwargs for runtime."""
|
|
393
399
|
new_args = []
|
|
394
|
-
|
|
395
|
-
|
|
400
|
+
fn = _handle_arg_predict if is_predict else _handle_arg
|
|
401
|
+
for arg, has_mutable_arg in zip(args, has_mutable_args_list):
|
|
402
|
+
new_arg = fn(obj, arg, has_mutable_arg)
|
|
396
403
|
if new_arg is not None:
|
|
397
404
|
new_args.append(new_arg)
|
|
398
405
|
|
|
399
406
|
for _, value in kwargs.items():
|
|
400
|
-
new_value =
|
|
407
|
+
new_value = fn(obj, value, None)
|
|
401
408
|
if new_value is not None:
|
|
402
409
|
new_args.append(new_value)
|
|
403
410
|
|
|
404
411
|
return new_args
|
|
405
412
|
|
|
406
413
|
|
|
407
|
-
def
|
|
408
|
-
"""Get
|
|
414
|
+
def _get_mutable_flags(compile_args):
|
|
415
|
+
"""Get a list of booleans indicating whether each argument is marked as mutable"""
|
|
409
416
|
new_args = []
|
|
410
|
-
for
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
for _, value in kwargs.items():
|
|
416
|
-
new_value = _handle_arg_predict(obj, value, None)
|
|
417
|
-
if new_value is not None:
|
|
418
|
-
new_args.append(new_value)
|
|
419
|
-
|
|
417
|
+
for compile_arg in compile_args:
|
|
418
|
+
has_mutable_arg = compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
|
|
419
|
+
getattr(compile_arg, "__ms_mutable__")
|
|
420
|
+
new_args.append(has_mutable_arg)
|
|
420
421
|
return new_args
|
|
421
422
|
|
|
422
423
|
|
|
@@ -544,10 +545,12 @@ def _get_parameter_ids(args, kwargs):
|
|
|
544
545
|
parameter_ids += str(id(value))
|
|
545
546
|
return parameter_ids
|
|
546
547
|
|
|
548
|
+
|
|
547
549
|
def _get_tensor_hook_key(tensor):
|
|
548
550
|
"""Get the hook key of Tensor/Parameter"""
|
|
549
551
|
return ".".join(map(str, map(id, tensor.hooks())))
|
|
550
552
|
|
|
553
|
+
|
|
551
554
|
def _get_hook_key(*args, **kwargs):
|
|
552
555
|
"""Get the hook key of Tensors/Parameters"""
|
|
553
556
|
hook_key = ""
|
|
@@ -586,13 +589,16 @@ class _JitExecutor:
|
|
|
586
589
|
The result of pipeline running in graph mode.
|
|
587
590
|
"""
|
|
588
591
|
|
|
589
|
-
def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0
|
|
592
|
+
def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0,
|
|
593
|
+
cell_cache_key_extend=''):
|
|
590
594
|
init_pipeline()
|
|
591
595
|
if not isinstance(fn, (types.FunctionType, types.MethodType)):
|
|
592
596
|
raise RuntimeError('fn {} is not function or method'.format(fn))
|
|
593
597
|
|
|
594
598
|
self.fn = fn
|
|
595
599
|
self.input_signature = input_signature
|
|
600
|
+
self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
|
|
601
|
+
self.enable_jit_dynamic = self.dynamic_args_shapes is not None
|
|
596
602
|
self.obj = None
|
|
597
603
|
if obj and hasattr(obj, fn.__name__):
|
|
598
604
|
self.obj = obj
|
|
@@ -606,9 +612,55 @@ class _JitExecutor:
|
|
|
606
612
|
self._compile_args = None
|
|
607
613
|
self._enable_auto_dynamic = dynamic == 1
|
|
608
614
|
self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
|
|
615
|
+
self._cell_cache_key_extend = cell_cache_key_extend
|
|
616
|
+
|
|
617
|
+
def _predict(self, *args, **kwargs):
|
|
618
|
+
"""Dedicated routine for predict."""
|
|
619
|
+
if not hasattr(self.obj, "phase"):
|
|
620
|
+
return False, None
|
|
621
|
+
|
|
622
|
+
predict_vailid_phase = {"prefill", 'increment'}
|
|
623
|
+
predict_phase = self.obj.phase
|
|
624
|
+
if predict_phase not in predict_vailid_phase:
|
|
625
|
+
return False, None
|
|
626
|
+
|
|
627
|
+
args_list = args
|
|
628
|
+
if self.obj is not None:
|
|
629
|
+
args_list = args_list[1:]
|
|
630
|
+
|
|
631
|
+
if predict_phase not in self.obj.phase_cache:
|
|
632
|
+
try:
|
|
633
|
+
predict_phase = self.compile(self.fn.__name__, *args_list, **kwargs)
|
|
634
|
+
except Exception as err:
|
|
635
|
+
_pynative_executor.clear_res()
|
|
636
|
+
raise err
|
|
637
|
+
else: # get compiled args to generate run args by _generate_run_args
|
|
638
|
+
compile_args = self._generate_compile_args(args_list)
|
|
639
|
+
key_id = self._get_key_id()
|
|
640
|
+
if self.input_signature is None:
|
|
641
|
+
compile_args = get_auto_dynamic_shape_args(
|
|
642
|
+
compile_args, key_id, self._enable_auto_dynamic
|
|
643
|
+
)
|
|
644
|
+
self._compile_args = compile_args
|
|
645
|
+
|
|
646
|
+
new_inputs = self._generate_run_args(args_list, kwargs)
|
|
647
|
+
if self.jit_config_dict:
|
|
648
|
+
jit_config_dict = self.jit_config_dict
|
|
649
|
+
else:
|
|
650
|
+
jit_config_dict = JitConfig().jit_config_dict
|
|
651
|
+
self._graph_executor.set_jit_config(jit_config_dict)
|
|
652
|
+
output = self._graph_executor(
|
|
653
|
+
tuple(new_inputs),
|
|
654
|
+
self.obj.phase_cache[self.obj.phase]
|
|
655
|
+
)
|
|
656
|
+
res = _convert_python_data(output)
|
|
657
|
+
return True, res
|
|
609
658
|
|
|
610
659
|
@_wrap_func
|
|
611
660
|
def __call__(self, *args, **kwargs):
|
|
661
|
+
predict, res = self._predict(*args, **kwargs)
|
|
662
|
+
if predict:
|
|
663
|
+
return res
|
|
612
664
|
if jit_context() and jit_context().is_nested():
|
|
613
665
|
return jit_context().run_graph("", None, *())
|
|
614
666
|
args_list = args
|
|
@@ -616,12 +668,9 @@ class _JitExecutor:
|
|
|
616
668
|
args_list = args_list[1:]
|
|
617
669
|
phase = ""
|
|
618
670
|
try:
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
_pynative_executor.set_jit_compile_status(False, phase)
|
|
623
|
-
else:
|
|
624
|
-
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
|
|
671
|
+
_pynative_executor.set_jit_compile_status(True, phase)
|
|
672
|
+
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
|
|
673
|
+
_pynative_executor.set_jit_compile_status(False, phase)
|
|
625
674
|
except Exception as err:
|
|
626
675
|
_pynative_executor.clear_res()
|
|
627
676
|
raise err
|
|
@@ -630,31 +679,27 @@ class _JitExecutor:
|
|
|
630
679
|
return None
|
|
631
680
|
|
|
632
681
|
new_inputs = self._generate_run_args(args_list, kwargs)
|
|
633
|
-
if
|
|
634
|
-
|
|
682
|
+
if self.jit_config_dict:
|
|
683
|
+
jit_config_dict = self.jit_config_dict
|
|
635
684
|
else:
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
685
|
+
jit_config_dict = JitConfig().jit_config_dict
|
|
686
|
+
self._graph_executor.set_jit_config(jit_config_dict)
|
|
687
|
+
output = _pynative_executor.grad_jit(*new_inputs)
|
|
688
|
+
if jit_context():
|
|
689
|
+
if is_stub_tensor(output):
|
|
690
|
+
output = output.stub_sync()
|
|
691
|
+
return jit_context().run_graph(phase, output, *tuple(new_inputs))
|
|
642
692
|
return output
|
|
643
693
|
|
|
644
694
|
def compile(self, method_name, *args, **kwargs):
|
|
645
695
|
"""Returns pipeline for the given args."""
|
|
646
|
-
# Check whether hook function registered on Cell object.
|
|
647
|
-
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
|
|
648
|
-
if self.obj._hook_fn_registered():
|
|
649
|
-
logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
|
|
650
|
-
f"If you want to use hook function, please use context.set_context to set "
|
|
651
|
-
f"pynative mode and remove 'jit' decorator.")
|
|
652
696
|
# Chose dynamic shape tensors or actual input tensors as compile args.
|
|
653
697
|
compile_args = self._generate_compile_args(args)
|
|
654
698
|
key_id = self._get_key_id()
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
699
|
+
if self.input_signature is None:
|
|
700
|
+
compile_args = get_auto_dynamic_shape_args(
|
|
701
|
+
compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
|
|
702
|
+
)
|
|
658
703
|
|
|
659
704
|
# Add mutable for compile_args for two scene:
|
|
660
705
|
# 1) Origin args is mutable.
|
|
@@ -674,7 +719,7 @@ class _JitExecutor:
|
|
|
674
719
|
f'`{self.fn.__module__}`')
|
|
675
720
|
self.obj.__parse_method__ = method_name
|
|
676
721
|
if isinstance(self.obj, ms.nn.Cell):
|
|
677
|
-
generate_name = generate_name + '.' + str(self.obj.create_time)
|
|
722
|
+
generate_name = generate_name + '.' + str(self.obj.create_time) + self.obj.phase
|
|
678
723
|
create_time = str(self.obj.create_time)
|
|
679
724
|
else:
|
|
680
725
|
generate_name = generate_name + '.' + str(self._create_time)
|
|
@@ -694,18 +739,23 @@ class _JitExecutor:
|
|
|
694
739
|
|
|
695
740
|
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
|
|
696
741
|
key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
|
|
742
|
+
key = str(key)
|
|
697
743
|
|
|
698
744
|
parameter_ids = _get_parameter_ids(args, kwargs)
|
|
699
745
|
if parameter_ids != "":
|
|
700
|
-
key
|
|
746
|
+
key += '.' + parameter_ids
|
|
747
|
+
|
|
748
|
+
key += "." + _get_hook_key(*args, **kwargs)
|
|
749
|
+
key += "." + str(_hook_version())
|
|
701
750
|
|
|
702
|
-
|
|
751
|
+
phase = generate_name + '.' + key
|
|
703
752
|
|
|
704
|
-
|
|
753
|
+
if self.input_signature is None:
|
|
754
|
+
update_auto_dynamic_shape_phase(compile_args, key_id, phase)
|
|
705
755
|
|
|
706
|
-
|
|
756
|
+
phase = phase + self._cell_cache_key_extend
|
|
707
757
|
|
|
708
|
-
if phase in ms_compile_cache and self._graph_executor.has_compiled(phase)
|
|
758
|
+
if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
|
|
709
759
|
# Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
|
|
710
760
|
# generated in generate_arguments_key.
|
|
711
761
|
self._graph_executor.clear_compile_arguments_resource()
|
|
@@ -716,31 +766,26 @@ class _JitExecutor:
|
|
|
716
766
|
# If enable compile cache, get the dependency files list and set to graph executor.
|
|
717
767
|
self._set_compile_cache_dep_files()
|
|
718
768
|
if self.jit_config_dict:
|
|
719
|
-
self.
|
|
769
|
+
jit_config_dict = self.jit_config_dict
|
|
720
770
|
else:
|
|
721
771
|
jit_config_dict = JitConfig().jit_config_dict
|
|
722
|
-
self._graph_executor.set_jit_config(jit_config_dict)
|
|
723
772
|
|
|
724
773
|
if self.obj is None:
|
|
725
774
|
# Set an attribute to fn as an identifier.
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
setattr(self.fn, "__jit_function__", True)
|
|
730
|
-
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
|
|
731
|
-
if isinstance(self.fn, types.MethodType):
|
|
732
|
-
delattr(self.fn.__func__, "__jit_function__")
|
|
733
|
-
else:
|
|
734
|
-
delattr(self.fn, "__jit_function__")
|
|
775
|
+
setattr(get_func(self.fn), "__jit_function__", True)
|
|
776
|
+
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
|
|
777
|
+
delattr(get_func(self.fn), "__jit_function__")
|
|
735
778
|
else:
|
|
736
779
|
if isinstance(self.obj, ms.nn.Cell):
|
|
737
780
|
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
|
738
|
-
is_compile = self._graph_executor.compile(
|
|
781
|
+
is_compile = self._graph_executor.compile(
|
|
782
|
+
self.obj, compile_args, kwargs, phase, jit_config_dict)
|
|
739
783
|
|
|
740
784
|
if not is_compile:
|
|
741
785
|
raise RuntimeError("Executor compile failed.")
|
|
742
|
-
set_parameter_hook_updated(False)
|
|
743
786
|
ms_compile_cache.add(phase)
|
|
787
|
+
if hasattr(self.obj, "phase"):
|
|
788
|
+
self.obj.phase_cache[self.obj.phase] = phase
|
|
744
789
|
|
|
745
790
|
return phase
|
|
746
791
|
|
|
@@ -785,41 +830,70 @@ class _JitExecutor:
|
|
|
785
830
|
if enable_compile_cache is True or enable_compile_cache == "1":
|
|
786
831
|
self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
|
|
787
832
|
|
|
833
|
+
def _generate_compile_args_by_enable_dynamic(self, args_list):
|
|
834
|
+
"""Generate compile args by enable_dynamic."""
|
|
835
|
+
compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
|
|
836
|
+
compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
|
|
837
|
+
if self.obj is not None:
|
|
838
|
+
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
|
|
839
|
+
else:
|
|
840
|
+
_pynative_executor.set_dynamic_input(self.fn, *compile_args)
|
|
841
|
+
logger.info(f"dynamic shape compile_args: {compile_args}")
|
|
842
|
+
return compile_args
|
|
843
|
+
|
|
844
|
+
def _generate_compile_args_by_set_inputs(self, args_list):
|
|
845
|
+
"""Generate compile args by set_inputs."""
|
|
846
|
+
compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
|
|
847
|
+
if len(compile_args) != len(args_list):
|
|
848
|
+
raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
|
|
849
|
+
f"dynamic shape tensors: {len(compile_args)}.")
|
|
850
|
+
self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
|
|
851
|
+
Validator.check_symbolic_shape(compile_args, args_list)
|
|
852
|
+
return compile_args
|
|
853
|
+
|
|
854
|
+
def _generate_compile_args_by_input_signature(self, args_list):
|
|
855
|
+
"""Generate compile args by input_signature."""
|
|
856
|
+
compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
|
|
857
|
+
dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
|
|
858
|
+
Validator.check_symbolic_shape(self.input_signature, args_list)
|
|
859
|
+
if dyn_shape:
|
|
860
|
+
# Checkout whether the `sens` has been added to args_list.
|
|
861
|
+
if len(compile_args) == len(args_list) - 1:
|
|
862
|
+
logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
|
|
863
|
+
f"of input_signature args '{len(compile_args)}'. The last actual args may "
|
|
864
|
+
f"be 'sens' and added it to compile args.")
|
|
865
|
+
compile_args.append(args_list[-1])
|
|
866
|
+
compile_args = tuple(compile_args)
|
|
867
|
+
self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
|
|
868
|
+
if self.obj is not None:
|
|
869
|
+
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
|
|
870
|
+
else:
|
|
871
|
+
_pynative_executor.set_dynamic_input(self.fn, *compile_args)
|
|
872
|
+
else:
|
|
873
|
+
if not verify_inputs_signature(compile_args, args_list):
|
|
874
|
+
raise ValueError("The input args is incompatible with the args in `input_signature`!")
|
|
875
|
+
return compile_args
|
|
876
|
+
|
|
877
|
+
def _check_set_inputs(self):
|
|
878
|
+
"""Check if the `set_inputs()` of Cell object has been set."""
|
|
879
|
+
return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
|
|
880
|
+
|
|
788
881
|
def _generate_compile_args(self, args_list):
|
|
789
882
|
"""Chose dynamic shape tensors or actual input tensors as compile args."""
|
|
790
|
-
# Case:
|
|
791
|
-
|
|
883
|
+
# Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
|
|
884
|
+
if self.enable_jit_dynamic and self._check_set_inputs():
|
|
885
|
+
raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
|
|
886
|
+
# Case: The `enable_dynamic` is provided.
|
|
887
|
+
if self.enable_jit_dynamic:
|
|
888
|
+
return self._generate_compile_args_by_enable_dynamic(args_list)
|
|
792
889
|
# Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
|
|
793
|
-
if self.
|
|
794
|
-
|
|
795
|
-
if len(compile_args) != len(args_list):
|
|
796
|
-
raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
|
|
797
|
-
f"dynamic shape tensors: {len(compile_args)}.")
|
|
798
|
-
self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
|
|
799
|
-
Validator.check_symbolic_shape(compile_args, args_list)
|
|
800
|
-
|
|
890
|
+
if self._check_set_inputs():
|
|
891
|
+
return self._generate_compile_args_by_set_inputs(args_list)
|
|
801
892
|
# Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
|
|
802
893
|
if self.input_signature is not None:
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
if dyn_shape:
|
|
807
|
-
# Checkout whether the `sens` has been added to args_list.
|
|
808
|
-
if len(compile_args) == len(args_list) - 1:
|
|
809
|
-
logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
|
|
810
|
-
f"of input_signature args '{len(compile_args)}'. The last actual args may "
|
|
811
|
-
f"be 'sens' and added it to compile args.")
|
|
812
|
-
compile_args.append(args_list[-1])
|
|
813
|
-
compile_args = tuple(compile_args)
|
|
814
|
-
self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
|
|
815
|
-
if self.obj is not None:
|
|
816
|
-
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
|
|
817
|
-
else:
|
|
818
|
-
_pynative_executor.set_dynamic_input(self.fn, *compile_args)
|
|
819
|
-
else:
|
|
820
|
-
if not verify_inputs_signature(compile_args, args_list):
|
|
821
|
-
raise ValueError("The input args is incompatible with the args in `input_signature`!")
|
|
822
|
-
return compile_args
|
|
894
|
+
return self._generate_compile_args_by_input_signature(args_list)
|
|
895
|
+
# Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
|
|
896
|
+
return _pynative_executor.get_dynamic_input(args_list)
|
|
823
897
|
|
|
824
898
|
def _generate_run_args(self, args_list, kwargs):
|
|
825
899
|
"""
|
|
@@ -832,7 +906,7 @@ class _JitExecutor:
|
|
|
832
906
|
Returns:
|
|
833
907
|
new_inputs, new input args, which are required for running.
|
|
834
908
|
"""
|
|
835
|
-
return _get_args_for_run(self, args_list, kwargs, self._compile_args)
|
|
909
|
+
return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)
|
|
836
910
|
|
|
837
911
|
def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
|
|
838
912
|
"""Get graph proto from pipeline."""
|
|
@@ -993,6 +1067,67 @@ def _check_options(options, backend):
|
|
|
993
1067
|
_check_option_value(option, value)
|
|
994
1068
|
|
|
995
1069
|
|
|
1070
|
+
def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
|
|
1071
|
+
"""Return the wrapped function for ast mode jit."""
|
|
1072
|
+
def wrap_func(func):
|
|
1073
|
+
nonlocal hash_obj
|
|
1074
|
+
if hasattr(func, "construct"):
|
|
1075
|
+
if isinstance(func, ms.nn.Cell):
|
|
1076
|
+
# Bound the cell object to get the self arg.
|
|
1077
|
+
return types.MethodType(_jit_ast(
|
|
1078
|
+
hash_obj, dynamic, jit_config, func._jit_graph_name)(func.construct.__func__), func)
|
|
1079
|
+
if isinstance(func, type) and issubclass(func, ms.nn.Cell):
|
|
1080
|
+
func.construct = _jit_ast(
|
|
1081
|
+
hash_obj, dynamic, jit_config, '')(func.construct)
|
|
1082
|
+
return func
|
|
1083
|
+
|
|
1084
|
+
if isinstance(func, types.MethodType):
|
|
1085
|
+
return types.MethodType(_jit_ast(hash_obj, dynamic, jit_config, '')(func.__func__), func.__self__)
|
|
1086
|
+
|
|
1087
|
+
if not isinstance(func, types.FunctionType):
|
|
1088
|
+
logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
|
|
1089
|
+
return func
|
|
1090
|
+
|
|
1091
|
+
if hasattr(func, "__wrapped_by_jit__"):
|
|
1092
|
+
logger.warning(f"The func {func} should be wrapped by jit only once.")
|
|
1093
|
+
|
|
1094
|
+
if hash_obj is None or not _is_inner_func(func):
|
|
1095
|
+
hash_obj = int(time.time() * 1e9)
|
|
1096
|
+
|
|
1097
|
+
@wraps(func)
|
|
1098
|
+
def staging_specialize(*args, **kwargs):
|
|
1099
|
+
if os.getenv("MS_JIT") == '0':
|
|
1100
|
+
return func(*args, **kwargs)
|
|
1101
|
+
|
|
1102
|
+
args, kwargs = _handle_func_args(func, *args, **kwargs)
|
|
1103
|
+
process_obj = None
|
|
1104
|
+
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
|
|
1105
|
+
process_obj = args[0]
|
|
1106
|
+
# Handle auto mixed precision strategy.
|
|
1107
|
+
if not hasattr(func, "amp_strategy"):
|
|
1108
|
+
setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
|
|
1109
|
+
|
|
1110
|
+
jit_graph_name = ''
|
|
1111
|
+
if hasattr(staging_specialize, "__jit_graph_name__"):
|
|
1112
|
+
jit_graph_name = staging_specialize.__jit_graph_name__
|
|
1113
|
+
jit_executor = _JitExecutor(
|
|
1114
|
+
func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
|
|
1115
|
+
out = jit_executor(*args, **kwargs)
|
|
1116
|
+
if isinstance(process_obj, ms.nn.Cell):
|
|
1117
|
+
_clear_auto_parallel_context(process_obj)
|
|
1118
|
+
return out
|
|
1119
|
+
|
|
1120
|
+
# `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
|
|
1121
|
+
# `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
|
|
1122
|
+
# original `func`.
|
|
1123
|
+
staging_specialize.__signature__ = inspect.signature(func)
|
|
1124
|
+
setattr(staging_specialize, "__wrapped_by_jit__", True)
|
|
1125
|
+
setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
|
|
1126
|
+
return staging_specialize
|
|
1127
|
+
|
|
1128
|
+
return wrap_func
|
|
1129
|
+
|
|
1130
|
+
|
|
996
1131
|
def jit(
|
|
997
1132
|
function: Optional[Callable] = None,
|
|
998
1133
|
*,
|
|
@@ -1015,45 +1150,45 @@ def jit(
|
|
|
1015
1150
|
and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
|
|
1016
1151
|
|
|
1017
1152
|
Args:
|
|
1018
|
-
function (
|
|
1153
|
+
function (Callable, optional): The Python function or Cell that will be run as a graph. Default: ``None``.
|
|
1019
1154
|
|
|
1020
1155
|
Keyword Args:
|
|
1021
1156
|
capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
|
|
1022
|
-
should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
|
|
1157
|
+
should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .
|
|
1023
1158
|
|
|
1024
|
-
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
change and/or deletion.
|
|
1029
|
-
- `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
|
|
1030
|
-
subject to change and/or deletion.
|
|
1159
|
+
- ast: Parse Python ast to build graph.
|
|
1160
|
+
- bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
|
|
1161
|
+
that is subject to change and/or deletion.
|
|
1162
|
+
- trace: Trace the execution of Python code to build graph. This is an experimental prototype
|
|
1163
|
+
that is subject to change and/or deletion.
|
|
1031
1164
|
|
|
1032
1165
|
jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
|
|
1033
|
-
with
|
|
1166
|
+
with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .
|
|
1034
1167
|
|
|
1035
|
-
-
|
|
1036
|
-
-
|
|
1168
|
+
- O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
|
|
1169
|
+
- O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
|
|
1037
1170
|
level is experimental and is being improved.
|
|
1038
1171
|
|
|
1039
1172
|
dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
|
|
1040
1173
|
is as follows:
|
|
1041
1174
|
|
|
1042
|
-
-
|
|
1043
|
-
-
|
|
1175
|
+
- 0: Do not perform dynamic shape compilation.
|
|
1176
|
+
- 1: Enable dynamic shape compilation and automatically detect shape changes.
|
|
1044
1177
|
|
|
1045
1178
|
fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
|
|
1046
1179
|
be compatible with all Python syntax in the function as much as possible. If True, we require that the
|
|
1047
1180
|
entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
|
|
1048
|
-
not supported), then it will raise an exception. This currently only applies when capture_mode is ast
|
|
1049
|
-
Default: ``False``.
|
|
1181
|
+
not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
|
|
1182
|
+
or ``bytecode``. Default: ``False``.
|
|
1050
1183
|
backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
|
|
1051
|
-
use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including
|
|
1052
|
-
A2 training series products by default.
|
|
1184
|
+
use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
|
|
1185
|
+
Atlas A2 training series products by default.
|
|
1053
1186
|
|
|
1054
|
-
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1187
|
+
- ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
|
|
1188
|
+
optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
|
|
1189
|
+
- GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
|
|
1190
|
+
for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
|
|
1191
|
+
and can be executed only on Ascend hardware.
|
|
1057
1192
|
|
|
1058
1193
|
**options (dict): A dictionary of options to pass to the compilation backend.
|
|
1059
1194
|
|
|
@@ -1076,11 +1211,11 @@ def jit(
               `disable_format_transform` can be set to ``True`` to try to improve training performance.
               Default: ``False`` .
             - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
-              methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
+              methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .
 
-              -
+              - bfs: The default sorting method, breadth priority, good communication masking, relatively good
                 performance.
-              -
+              - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
                 of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
                 other execution orders run out of memory (OOM).
 
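The bfs/dfs trade-off described above is chosen per compiled function. A small sketch follows, assuming `exec_order` is forwarded as one of the `**options` keyword arguments, in the same way the `ge_options` example later in this docstring passes backend options; this usage is inferred, not shown in the diff.

    import numpy as np
    from mindspore import Tensor, jit

    # Assumption: exec_order is handed to ms_backend via **options; "dfs" trades a
    # little performance for lower peak memory, per the description above.
    @jit(backend="ms_backend", exec_order="dfs")
    def fused(x, y):
        return x * y + x

    x = Tensor(np.ones((4, 4), np.float32))
    print(fused(x, x))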
@@ -1091,11 +1226,11 @@ def jit(
               - global (dict): Set global options.
               - session (dict): Set session options.
 
-            - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
+            - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
               the inference mode is disabled. The range is as follows:
 
-              -
-              -
+              - on: Enable inference mode, get better infer performance.
+              - off: Disable inference mode, use forward for inference. The performance is poor.
 
     Returns:
         Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
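`infer_boost` would follow the same pattern; the sketch below is hedged: the option name and its "on"/"off" values come from the docstring above, while the exact forwarding into the backend options is assumed.

    import numpy as np
    from mindspore import Tensor, jit

    # Assumption: infer_boost="on" is accepted through **options and enables the
    # inference-optimized path described above; "off" keeps the regular forward path.
    @jit(backend="ms_backend", infer_boost="on")
    def predict(x):
        return x * 0.5 + 1.0

    print(predict(Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))))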
@@ -1114,29 +1249,84 @@ def jit(
         >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph by calling jit.
         >>> def tensor_add(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> tensor_add_graph = jit(function=tensor_add)
         >>> out = tensor_add_graph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph through decorator @jit.
         >>> @jit
         ... def tensor_add_with_dec(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> out = tensor_add_with_dec(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph and capture the entire function into the graph.
         >>> @jit(fullgraph=True)
         ... def tensor_add_fullgraph(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> out = tensor_add_fullgraph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph by trace mode.
+        >>> @jit(capture_mode="trace")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with ms_backend and jit_level="O1".
+        >>> @jit(backend="ms_backend", jit_level="O1")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with GE backend and some ge options on Ascend.
+        >>> @jit(backend="GE", ge_options={"global": {"ge.opSelectImplmode": "high_precision"}})
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
     """
 
     capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
@@ -1155,39 +1345,12 @@ def jit(
     jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
                            infer_boost=infer_boost, backend=backend, options=options_str)
 
-
-
-
-
-
-
-        def staging_specialize(*args, **kwargs):
-            if os.getenv("MS_JIT") == '0':
-                return func(*args, **kwargs)
-
-            args, kwargs = _handle_func_args(func, *args, **kwargs)
-            process_obj = None
-            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
-                process_obj = args[0]
-            # Handle auto mixed precision strategy.
-            if not hasattr(func, "amp_strategy"):
-                if isinstance(func, types.MethodType):
-                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
-                else:
-                    setattr(func, "amp_strategy", get_curr_amp_strategy())
-
-            ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
-            out = ms_function_executor(*args, **kwargs)
-            return out
-
-        return staging_specialize
-
-    if capture_mode == "bytecode":
-        wrap_func = PIJitCaptureContext(jit_config)
-    elif capture_mode == "trace":
-        if function is not None:
-            return _jit_trace(function)
-        return _jit_trace
+    if capture_mode == "ast":
+        wrap_func = _jit_ast(hash_obj, dynamic, jit_config, '')
+    elif capture_mode == "bytecode":
+        wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
+    else:
+        wrap_func = _jit_trace()
 
     if function is not None:
         return wrap_func(function)
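From the caller's side, the three capture modes dispatched above (`ast` via `_jit_ast`, `bytecode` via `PIJitCaptureContext`, `trace` via `_jit_trace`) are selected with the same keyword. A sketch, not taken from this diff:

    import numpy as np
    from mindspore import Tensor, jit

    def add(x, y):
        return x + y

    x = Tensor(np.ones((3,), np.float32))
    y = Tensor(np.ones((3,), np.float32))

    # The three values accepted by the capture_mode validator shown earlier:
    # "ast" parses the Python source, "bytecode" captures at bytecode level,
    # "trace" records the operations actually executed.
    for mode in ("ast", "bytecode", "trace"):
        compiled = jit(add, capture_mode=mode)
        print(mode, compiled(x, y))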
@@ -1503,7 +1666,7 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def check_run(self, grad, obj, weights, grad_hash_id, *args):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
         """
         Whether the forward graph need to construct.
 
@@ -1516,7 +1679,7 @@ class _PyNativeExecutor:
         Return:
             bool, specifies whether the forward graph needs to construct.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, **kwargs)
 
     def grad(self, obj, grad, weights, grad_position, *args):
         """
@@ -1758,6 +1921,19 @@ class _PyNativeExecutor:
         """
         return self._executor.constant_folding(*args)
 
+    def set_creation_type(self, tensor, creation_type):
+        """
+        Set tensor's view creation type.
+
+        Args:
+            tensor (Tensor): input tensor.
+            creation_type (CreationType): The type of view tensor when it is created.
+
+        Return:
+            None.
+        """
+        return self._executor.set_creation_type(tensor, creation_type)
+
 
 class _CellGraphExecutor:
     """
@@ -1834,13 +2010,6 @@ class _CellGraphExecutor:
         else:
             _set_dataset_mode_config('normal')
 
-    @staticmethod
-    def _use_vm_mode():
-        enable_ge = context.get_context("enable_ge")
-        enable_debug_runtime = context.get_context("enable_debug_runtime")
-        exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
-        return not enable_ge or (enable_debug_runtime and exe_mode)
-
     def _build_data_graph(self, obj, phase):
         self._graph_executor.build_data_graph(obj.parameters_dict(), phase)
 
@@ -1872,7 +2041,12 @@ class _CellGraphExecutor:
         obj.__parse_method__ = 'construct'
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
-                'The class {}
+                'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+        inner_func = inspect.unwrap(obj.construct)
+        if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
+            raise ValueError(
+                "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
+            )
         key_id = str(id(obj)) + str(obj.create_time)
         args = get_auto_dynamic_shape_args(args, key_id)
 
@@ -1883,20 +2057,25 @@ class _CellGraphExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
 
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
-
-
-        obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
+        key = str(key)
 
         # When exist parameter in the top graph inputs, need check if the parameter object has changed.
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-
+            key += '.' + parameter_ids
+
+        key += "." + _get_hook_key(*args, **kwargs)
+        key += "." + str(_hook_version())
+
+        obj.arguments_key = key
+
         raw_phase = phase
-
+
+        phase = _real_phase(phase, obj)
         obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
         obj.current_phase = phase
-        if phase in obj.compile_cache and self.has_compiled(phase)
+        if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
@@ -1904,7 +2083,7 @@ class _CellGraphExecutor:
             _clear_auto_parallel_context(obj)
             return phase, False
 
-        full_function_name = obj.__class__.__name__ + '.' + str(obj.
+        full_function_name = obj.__class__.__name__ + '.' + str(obj.total_instance_count) + '.' + str(id(type(obj)))
         echo_function_name = obj.__class__.__name__
         _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
 
@@ -1914,17 +2093,14 @@ class _CellGraphExecutor:
         self._set_compile_cache_dep_files(phase)
 
         self._graph_executor.set_weights_values(obj.parameters_dict())
-        if jit_config_dict:
-            self._graph_executor.set_jit_config(jit_config_dict)
-        else:
+        if not jit_config_dict:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)
         gc.collect()
-        result = self._graph_executor.compile(
+        result = self._graph_executor.compile(
+            obj, args, kwargs, phase, jit_config_dict)
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
-        set_parameter_hook_updated(False)
         graph = self._graph_executor.get_func_graph(phase)
 
         if graph is None:
@@ -1949,15 +2125,15 @@ class _CellGraphExecutor:
         return self._graph_executor.updata_param_node_default_input(phase, new_param)
 
     def _get_shard_strategy(self, obj):
-        real_phase = obj.phase
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_strategy(real_phase)
 
     def _get_num_parallel_ops(self, obj):
-        real_phase = obj.phase
+        real_phase = _real_phase(obj.phase, obj)
        return self._graph_executor.get_num_parallel_ops(real_phase)
 
     def _get_allreduce_fusion(self, obj):
-        real_phase = obj.phase
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_allreduce_fusion(real_phase)
 
     def __call__(self, obj, *args, phase='predict'):
@@ -2009,10 +2185,10 @@ class _CellGraphExecutor:
             Tensor/Tuple, return execute result.
         """
         if phase == 'save':
-            exe_phase = phase
+            exe_phase = _real_phase(phase, obj)
             return self._graph_executor((), exe_phase)
 
-        phase_real = phase
+        phase_real = _real_phase(phase, obj)
         if self.has_compiled(phase_real):
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -2039,7 +2215,7 @@ class _CellGraphExecutor:
 
     def get_optimize_graph_proto(self, obj):
         """Return optimize graph binary proto."""
-        exec_id = obj.phase
+        exec_id = _real_phase(obj.phase, obj)
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
         graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
@@ -2121,5 +2297,3 @@ def flops_collection(phase='train'):
 
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
-
-__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']