mindspore-2.1.0-cp38-none-any.whl → mindspore-2.2.11-cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED

@@ -26,13 +26,14 @@ import inspect
 import importlib
 import hashlib
 import contextlib
-from collections import OrderedDict
+from collections import OrderedDict, namedtuple
 from functools import wraps
 import numpy as np
 import mindspore as ms
 from mindspore import context
 from mindspore import log as logger
 from mindspore._extends.remote import kernel_build_server
+from mindspore.common.jit_config import JitConfig
 from mindspore.common.tensor import Tensor as PythonTensor
 from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
@@ -48,11 +49,15 @@ from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
 from mindspore.common.mutable import mutable
 from mindspore.common._register_for_adapter import ms_adapter_registry
+from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
+    get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature

 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
 # Store cell compiled pipeline cache.
 cells_compile_cache = {}
+# Store function compiled times information.
+function_phases = dict()

 BROADCAST_PHASE = "_broadcast_"
 _PYNATIVE_PARALLEL_FUNC_NAME = "after_shard"
@@ -79,6 +84,12 @@ def _convert_python_data(data):
     if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
         return PythonRowTensor(row_tensor=data)
     if isinstance(data, tuple):
+        # Handle namedtuple since its type is tuple.
+        if hasattr(data, "_fields"):
+            type_name = data.__class__.__name__
+            data_dict = data._asdict()
+            fields = data_dict.keys()
+            return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
         return tuple(_convert_python_data(x) for x in data)
     if isinstance(data, list):
         # Keep list object not change for inplace operation.
@@ -86,7 +97,11 @@ def _convert_python_data(data):
             data[i] = _convert_python_data(data[i])
         return data
     if isinstance(data, dict):
-
+        # Keep the dict object not change.
+        keys = tuple(data.keys())
+        for key in keys:
+            data[_convert_python_data(key)] = _convert_python_data(data.pop(key))
+        return data
     return data


@@ -175,8 +190,7 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
         if isinstance(node, ast.ImportFrom):
             if node.module is not None:
                 module_name = node.module
-
-                module_name = "." + module_name
+            module_name = "." * node.level + module_name
         elif not isinstance(node, ast.Import):
             continue
         # Do not care the files in mindspore package
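The namedtuple branch added to `_convert_python_data` above is easier to see in isolation. Below is a minimal, self-contained sketch of the same round-trip; the toy `convert` function and the `Point` type are illustrative stand-ins, not MindSpore API. Because a namedtuple's type is `tuple`, it is detected through its `_fields` attribute and rebuilt field by field so that attribute access survives the conversion:

```python
from collections import namedtuple

def convert(data):
    # Toy stand-in for _convert_python_data (tuple branches only).
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # Rebuild a namedtuple class with the same name and fields,
        # converting each field value recursively.
        data_dict = {k: convert(v) for k, v in data._asdict().items()}
        return namedtuple(data.__class__.__name__, data_dict.keys())(**data_dict)
    if isinstance(data, tuple):
        return tuple(convert(x) for x in data)
    return data

Point = namedtuple("Point", ["x", "y"])
p = convert(Point(x=1, y=(2, 3)))
assert (p.x, p.y) == (1, (2, 3)) and type(p).__name__ == "Point"
```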
@@ -284,195 +298,6 @@ def _get_args_for_run(obj, args, kwargs):
     return new_args


-class _AutoIdentifyDynamicShape:
-
-    """
-    Represents a function auto identify dynamic shape.
-    """
-    def __init__(self):
-        self.all_shape_cache = {}
-
-
-    @staticmethod
-    def get_input_tensor_shape(args_list):
-        """get input tensor shape and type save as tensor, and make it value to 1"""
-        tensor_list = []
-        for arg in args_list:
-            if isinstance(arg, Tensor):
-                tmp_shape = arg.shape
-                tmp_type = arg.dtype
-                tensor_list.append(PythonTensor(np.ones(tmp_shape), dtype=tmp_type))
-            else:
-                tensor_list.append(arg)
-
-        return tuple(tensor_list)
-
-
-    @staticmethod
-    def check_input_args(args_list):
-        """check input args"""
-        if not args_list:
-            return False
-
-        for elem in args_list:
-            if elem is None:
-                return False
-
-            if isinstance(elem, ms.Parameter):
-                return False
-
-            if not isinstance(elem, Tensor):
-                return False
-
-            if elem.const_arg:
-                return False
-
-        return True
-
-
-    @staticmethod
-    def _is_tensor_equal(input_tensor, cache_tensor):
-        """check two tensor is equal"""
-        if input_tensor.dtype != cache_tensor.dtype:
-            return False
-
-        if input_tensor.shape != cache_tensor.shape:
-            return False
-
-        if len(input_tensor.shape) != len(cache_tensor.shape):
-            return False
-
-        return True
-
-
-    @staticmethod
-    def _is_all_input_shape_generalize(input_shape_tuple):
-        """check all input shapes need generalize"""
-        for elem in input_shape_tuple:
-            if not is_shape_unknown(elem.shape):
-                return False
-        return True
-
-
-    def auto_dynamic_generate_compile_args(self, args_list):
-        """generate compile args in auto dynamic shape"""
-        if not self._is_enable_auto_identify_shape(args_list):
-            return args_list
-
-        args_len = len(args_list)
-        tensor_tuple = self.get_input_tensor_shape(args_list)
-        shape_cache_list = self.all_shape_cache.get(args_len)
-        # step1: init real_shape_cache, part_generalize_shape_cache, all_generalize_shape_cache.
-        if shape_cache_list is None:
-            shape_cache_list = []
-            real_shape_cache = set()
-            real_shape_cache.add(tensor_tuple)
-            shape_cache_list.append(real_shape_cache)
-            part_generalize_shape_cache = set()
-            shape_cache_list.append(part_generalize_shape_cache)
-            all_generalize_shape_cache = set()
-            shape_cache_list.append(all_generalize_shape_cache)
-            self.all_shape_cache[args_len] = shape_cache_list
-            logger.info((f'The real shape cache is empty, add it into real_shape_cache.'))
-            return tensor_tuple
-
-        # step2: find cache in real_shape_cache.
-        real_shape_cache = shape_cache_list[0]
-        is_real_shape_exist, real_shape_input = self._find_compile_args_in_shape_cache(real_shape_cache, tensor_tuple,
-                                                                                       "real")
-        if is_real_shape_exist and real_shape_input is not None:
-            return real_shape_input
-
-        # step3: if can not find cache in real_shape_cache, then generate it
-        is_generalize_shape, compile_args = self._do_generalize_shape(real_shape_cache, tensor_tuple)
-
-        # step4: if input type change or rank change, save shape into real_shape_cache and then return
-        if not is_generalize_shape and compile_args is None:
-            real_shape_cache.add(tensor_tuple)
-            return tensor_tuple
-
-        # step5: check whether all input tensor need generalize
-        all_generalize_shape_cache = shape_cache_list[2]
-        if self._is_all_input_shape_generalize(compile_args):
-            if not all_generalize_shape_cache:
-                all_generalize_shape_cache.add(compile_args)
-            logger.info((f'return all generalize shape cache.'))
-            return compile_args
-
-        # step6: find compile_args in part_generalize_shape_cache
-        part_generalize_shape_cache = shape_cache_list[1]
-        if not part_generalize_shape_cache:
-            part_generalize_shape_cache.add(compile_args)
-        else:
-            is_generalize_shape_exist, _ = self._find_compile_args_in_shape_cache(part_generalize_shape_cache,
-                                                                                  compile_args, "part generalize")
-            if not is_generalize_shape_exist:
-                logger.info((f'Can not find cache in part_generalize_shape_cache, add it into'
-                             ' part_generalize_shape_cache.'))
-                part_generalize_shape_cache.add(compile_args)
-
-        return compile_args
-
-
-    def _is_all_tensor_equal(self, input_shape_tuple, cache_shape_tuple):
-        """check two tuple is equal"""
-        for i, elem in enumerate(cache_shape_tuple):
-            res = self._is_tensor_equal(input_shape_tuple[i], elem)
-            if not res:
-                return False
-        return True
-
-
-    def _is_enable_auto_identify_shape(self, args_list):
-        """is enable auto identify shape"""
-        enable_auto_identify = os.getenv('MS_AUTO_DYNAMIC_SHAPE_ENABLE')
-        if not enable_auto_identify:
-            enable_auto_identify = False
-        if ((enable_auto_identify is False or enable_auto_identify == "0")) or not self.check_input_args(args_list):
-            return False
-        return True
-
-
-    def _find_compile_args_in_shape_cache(self, shape_cache, compile_args, cache_type):
-        """find compile args in real or part generalize shape cache"""
-        is_exist = False
-        for shapes in shape_cache:
-            is_exist = self._is_all_tensor_equal(compile_args, shapes)
-            if is_exist:
-                logger.info((f'Find cache in {cache_type} shape cache.'))
-                return is_exist, shapes
-        logger.info((f'Can not find cache in {cache_type} shape cache.'))
-        return is_exist, None
-
-
-    def _do_generalize_shape(self, real_shape_cache, tensor_tuple):
-        """do generalize shape"""
-        is_generalize_shape = False
-        for real_shape in real_shape_cache:
-            generalize_shape = []
-            for i, elem in enumerate(real_shape):
-                if len(elem.shape) != len(tensor_tuple[i].shape) or elem.dtype != tensor_tuple[i].dtype:
-                    generalize_shape.clear()
-                    break
-                if (not is_shape_unknown(elem.shape)) and self._is_tensor_equal(tensor_tuple[i], elem):
-                    generalize_shape.append(tensor_tuple[i])
-                else:
-                    shape_value = []
-                    for _ in range(len(elem.shape)):
-                        shape_value.append(-1)
-                    shape_tuple = tuple(shape_value)
-                    generalize_shape.append(PythonTensor(Tensor(shape=shape_tuple, dtype=tensor_tuple[i].dtype)))
-                    logger.info((f'The {i} input tensor shape is {tensor_tuple[i].shape}, type is '
-                                 f'{tensor_tuple[i].dtype}; in real cache shape is {elem.shape}, type is '
-                                 f'{elem.dtype}, the {i} input shape not equal, may generalize to {shape_tuple}.'))
-
-            if len(generalize_shape) == len(real_shape):
-                is_generalize_shape = True
-                return is_generalize_shape, tuple(generalize_shape)
-
-        return is_generalize_shape, None
-
-
 class _MindsporeFunctionExecutor:
     """
     Represents a function compiled by graph compiler.
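The class removed above is not gone from the package: this release moves auto dynamic shape handling into the new `mindspore/common/auto_dynamic_shape.py` (see the file list) and reaches it through the `get_auto_dynamic_shape_args*` imports added at the top of this file. The heart of the removed logic is the generalization step in `_do_generalize_shape`; here is a standalone sketch of that policy, independent of MindSpore:

```python
def generalize(cached_shape, new_shape):
    """All-or-nothing generalization, as in the removed _do_generalize_shape:
    an exact match reuses the static graph; a mismatch at equal rank marks
    every axis dynamic (-1); a rank change is cached as a new real shape."""
    if len(cached_shape) != len(new_shape):
        return None                        # rank change: new real-shape cache entry
    if cached_shape == new_shape:
        return new_shape                   # exact hit: reuse the compiled graph
    return tuple(-1 for _ in new_shape)    # one dynamic graph for this rank

assert generalize((32, 128), (32, 128)) == (32, 128)
assert generalize((32, 128), (64, 128)) == (-1, -1)
assert generalize((32, 128), (32,)) is None
```

Generalizing every axis to -1 at the first mismatch buys one extra compilation in exchange for reusing a single dynamic-shape graph across all later inputs of the same rank and dtype.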
@@ -490,7 +315,6 @@ class _MindsporeFunctionExecutor:
     Returns:
         The result of pipeline running in graph mode.
     """
-
     def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
@@ -506,7 +330,7 @@ class _MindsporeFunctionExecutor:
         self._graph_executor = GraphExecutor_.get_instance()
         self._create_time = ms_create_time
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
-
+

     @_wrap_func
     def __call__(self, *args, **kwargs):
@@ -516,9 +340,9 @@ class _MindsporeFunctionExecutor:
         phase = ""
         try:
             if context.get_context("mode") == context.PYNATIVE_MODE:
-                _pynative_executor.
+                _pynative_executor.set_jit_compile_status(True, phase)
                 phase = self.compile(self.fn.__name__, *args_list, **kwargs)
-                _pynative_executor.
+                _pynative_executor.set_jit_compile_status(False, phase)
             else:
                 phase = self.compile(self.fn.__name__, *args_list, **kwargs)
         except Exception as err:
@@ -531,19 +355,11 @@ class _MindsporeFunctionExecutor:
         new_inputs = self._generate_run_args(args_list, kwargs)
         output = self._graph_executor(tuple(new_inputs), phase)
         if context.get_context("mode") == context.PYNATIVE_MODE:
-            output = _pynative_executor.
-
-            enable_ge = os.getenv("MS_ENABLE_GE") == "1"
-            if enable_ge and self.jit_config_dict is None:
-                raise RuntimeError("GE and jit_level=O3 should be used together, but jit_config is None.")
-            if self.jit_config_dict:
-                enable_jit_level_o3 = self.jit_config_dict.get('jit_level') == "O3"
-                if (enable_ge and not enable_jit_level_o3) or (not enable_ge and enable_jit_level_o3):
-                    raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jit_level={}".
-                                       format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
+            output = _pynative_executor.grad_jit(output, *new_inputs)

         return output

+
     def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
         # Check whether hook function registered on Cell object.
@@ -554,14 +370,16 @@ class _MindsporeFunctionExecutor:
                                 f"pynative mode and remove 'jit' decorator.")
         # Chose dynamic shape tensors or actual input tensors as compile args.
         compile_args = self._generate_compile_args(args)
+        key_id = self._get_key_id()
+        compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
+                                                                              self.input_signature)
+
         # Restore the mutable attr for every arg.
         compile_args = _restore_mutable_attr(args, compile_args)
-        generate_name = self.
-
-
-
-        if _is_pynative_parallel():
-            generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
+        generate_name, echo_function_name = self._get_generate_name()
+        # The full Function name
+        full_function_name = generate_name
+        create_time = ''

         # Add key with obj
         if self.obj is not None:
@@ -572,13 +390,18 @@ class _MindsporeFunctionExecutor:
             self.obj.__parse_method__ = method_name
             if isinstance(self.obj, ms.nn.Cell):
                 generate_name = generate_name + '.' + str(self.obj.create_time)
+                create_time = str(self.obj.create_time)
             else:
                 generate_name = generate_name + '.' + str(self._create_time)
+                create_time = str(self._create_time)
+
             generate_name = generate_name + '.' + str(id(self.obj))
+            full_function_name = generate_name
         else:
             # Different instance of same class may use same memory(means same obj_id) at diff times.
             # To avoid unexpected phase matched, add create_time to generate_name.
             generate_name = generate_name + '.' + str(self._create_time)
+            create_time = str(self._create_time)

         self.enable_tuple_broaden = False
         if hasattr(self.obj, "enable_tuple_broaden"):
@@ -587,16 +410,33 @@ class _MindsporeFunctionExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
         phase = generate_name + '.' + str(key)
+
+        update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
+
         if phase in ms_compile_cache:
             return phase

+        self._check_recompile(full_function_name, create_time, echo_function_name)
+
         # If enable compile cache, get the dependency files list and set to graph executor.
         self._set_compile_cache_dep_files()
         if self.jit_config_dict:
             self._graph_executor.set_jit_config(self.jit_config_dict)
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+            self._graph_executor.set_jit_config(jit_config_dict)

         if self.obj is None:
+            # Set an attribute to fn as an identifier.
+            if isinstance(self.fn, types.MethodType):
+                setattr(self.fn.__func__, "__jit_function__", True)
+            else:
+                setattr(self.fn, "__jit_function__", True)
             is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
+            if isinstance(self.fn, types.MethodType):
+                delattr(self.fn.__func__, "__jit_function__")
+            else:
+                delattr(self.fn, "__jit_function__")
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
@@ -605,8 +445,32 @@ class _MindsporeFunctionExecutor:
         if not is_compile:
             raise RuntimeError("Executor compile failed.")
         ms_compile_cache.add(phase)
+
         return phase

+    def _check_recompile(self, full_function_name, create_time, echo_function_name):
+        """Warning when the function has been compiled."""
+        ignore_dirs = ["mindspore/ops", "mindspore/nn"]
+        if any((lambda x: x in full_function_name)(x) for x in ignore_dirs):
+            return
+
+        if full_function_name in function_phases:
+            warning_times = 1
+            if len(function_phases[full_function_name]) >= warning_times \
+                    and create_time not in function_phases[full_function_name]:
+                tips = "Try to decorate the function with @jit(hash_args=...) " \
+                       "or @jit(compile_once=True) to reduce the compile time. " \
+                       "For more details, get instructions about `jit` at " \
+                       "https://www.mindspore.cn/search?inputValue=jit."
+
+                logger.warning(f"The {echo_function_name} has been compiled again. "
+                               f"{tips} ")
+        else:
+            function_phases[full_function_name] = set()
+
+        function_phases[full_function_name].add(create_time)
+
+
     @staticmethod
     def _optimizer_state_init(opt_states):
         """set data for all optimizer states in case it is executed in graph mode"""
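Stripped of MindSpore specifics, the `_check_recompile` bookkeeping added above is a small pattern: key each jit-compiled function by its source coordinates, remember every `create_time` seen under that key, and warn once a second distinct `create_time` appears, meaning the same source function was wrapped and compiled again (for example, re-created inside a loop). A standalone sketch with hypothetical names:

```python
import logging

function_phases = {}  # full function name -> creation timestamps already compiled

def check_recompile(full_name, create_time):
    """Warn when a source function is compiled under a second creation time."""
    seen = function_phases.setdefault(full_name, set())
    if seen and create_time not in seen:
        logging.warning("%s has been compiled again; consider @jit(hash_args=...) "
                        "or @jit(compile_once=True).", full_name)
    seen.add(create_time)

check_recompile("mod.train_step:42", 1001)  # first compile: silent
check_recompile("mod.train_step:42", 1002)  # wrapper re-created: warns
```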
@@ -618,6 +482,31 @@ class _MindsporeFunctionExecutor:
             opt_param.init_data()


+    def _get_key_id(self):
+        """get key id."""
+        if isinstance(self.obj, ms.nn.Cell):
+            key_id = str(id(self.obj)) + str(self.obj.create_time)
+        else:
+            key_id = str(id(self.obj)) + str(self._create_time)
+
+        if _pynative_executor.grad_flag():
+            key_id = key_id + ".grad"
+        return key_id
+
+
+    def _get_generate_name(self):
+        """get generate name."""
+        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + str(
+            self.fn.__code__.co_firstlineno)
+        echo_function_name = "function \"" + self.fn.__name__ + "\" at the file \"" + self.fn.__code__.co_filename \
+                             + "\", line " + str(self.fn.__code__.co_firstlineno)
+        if _pynative_executor.grad_flag():
+            generate_name = generate_name + ".grad"
+        if _is_pynative_parallel():
+            generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
+        return generate_name, echo_function_name
+
+
     def _set_compile_cache_dep_files(self):
         # If enable compile cache, get the dependency files list
         enable_compile_cache = context.get_context("enable_compile_cache")
@@ -630,7 +519,7 @@ class _MindsporeFunctionExecutor:
     def _generate_compile_args(self, args_list):
         """Chose dynamic shape tensors or actual input tensors as compile args."""
         # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
-        compile_args = args_list
+        compile_args = _pynative_executor.get_dynamic_input(args_list)
         # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
         if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
             compile_args = self.obj.get_inputs()
@@ -659,12 +548,13 @@ class _MindsporeFunctionExecutor:
                                 f"be 'sens' and added it to compile args.")
                 self.input_signature.append(args_list[-1])
                 compile_args = tuple(self.input_signature)
-
+                if self.obj is not None:
+                    _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+                else:
+                    _pynative_executor.set_dynamic_input(self.fn, *compile_args)
             else:
                 if not verify_inputs_signature(self.input_signature, args_list):
                     raise ValueError("The input args is incompatible with the args in `input_signature`!")
-        else:
-            compile_args = self.auto_identify_dynamic_shape.auto_dynamic_generate_compile_args(args_list)
         return compile_args

     def _generate_run_args(self, args_list, kwargs):
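The new `_get_key_id`/`_get_generate_name` helpers make the compile cache key easy to state: a phase is built from the function's source coordinates (module, name, file, first line), the creation time and object id, and finally the generated arguments key, so two compilations share a cache entry only when all of these match. A toy illustration of that composition (`make_phase` and `args_key` are illustrative names, not MindSpore API):

```python
def make_phase(fn, create_time, obj_id, args_key):
    # Stable source coordinates separate functions; create_time and obj_id
    # separate instances; args_key separates argument signatures.
    generate_name = ".".join([fn.__module__, fn.__name__,
                              fn.__code__.co_filename,
                              str(fn.__code__.co_firstlineno)])
    return f"{generate_name}.{create_time}.{obj_id}.{args_key}"

def f(x):
    return x

# Re-creating the "same" function at a new time yields a distinct phase.
assert make_phase(f, 1, id(f), "k") != make_phase(f, 2, id(f), "k")
```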
@@ -699,14 +589,14 @@ def _get_obj_id(input_obj):
|
|
|
699
589
|
return obj_id + str(id(input_obj))
|
|
700
590
|
|
|
701
591
|
|
|
702
|
-
def
|
|
592
|
+
def _get_jit_hash(hash_input):
|
|
703
593
|
"""Get hash value of single object or list of objects."""
|
|
704
594
|
if isinstance(list, tuple):
|
|
705
595
|
return ".".join(map(_get_obj_id, hash_input))
|
|
706
596
|
return _get_obj_id(hash_input)
|
|
707
597
|
|
|
708
598
|
|
|
709
|
-
def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|
599
|
+
def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, compile_once=False):
|
|
710
600
|
"""
|
|
711
601
|
Create a callable MindSpore graph from a Python function.
|
|
712
602
|
|
|
@@ -726,6 +616,10 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|
|
726
616
|
like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
|
|
727
617
|
will trigger recompilation. Default: ``None`` .
|
|
728
618
|
jit_config (JitConfig): Jit config for compile. Default: ``None`` .
|
|
619
|
+
compile_once(bool): ``True``: The function would be compiled once when it was created many times.
|
|
620
|
+
But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
|
|
621
|
+
it was created again
|
|
622
|
+
Default: ``False`` .
|
|
729
623
|
|
|
730
624
|
Returns:
|
|
731
625
|
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
|
@@ -769,7 +663,7 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
         ...
         >>> out = tensor_add_with_sig(x, y)
         ...
-        ... # Set hash_args as fn, otherwise cache of compiled
+        ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
         ... # While fn differs during calling again, recompilation will be triggered.
         >>> def func(x):
         ...     return ops.exp(x)
@@ -783,11 +677,28 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
         >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
         >>> for i in range(10):
         ...     closure_fn(inputs, func)
+        ...
+        ... # Set compile_once = True, otherwise the train_step will be compiled again.
+        >>> def train(x):
+        ...     @jit(compile_once = True)
+        ...     def train_step(x):
+        ...         return ops.exp(x)
+        ...     for i in range(10):
+        ...         train_step(x)
+        ...
+        >>> inputs = Tensor(np.ones([10, 10, 10]).astype(np.float32))
+        >>> for i in range(10):
+        ...     train(inputs)
     """

     def wrap_mindspore(func):
+        if not isinstance(compile_once, bool):
+            logger.warning(f"The parameter `compile_once` of jit should be a bool, "
+                           f"but got {type(compile_once)}.")
         if hash_args:
-            hash_obj =
+            hash_obj = _get_jit_hash(hash_args)
+        elif compile_once:
+            hash_obj = 0
         else:
             hash_obj = int(time.time() * 1e9)

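The three-way branch above is the whole compile_once mechanism: the chosen hash_obj becomes part of the compile-cache key. A stand-alone sketch of the policy (illustrative only; the get_jit_hash parameter stands in for the _get_jit_hash helper shown earlier in this diff):

    import time

    def pick_cache_key(hash_args=None, compile_once=False, get_jit_hash=id):
        # hash_args given: the key follows the listed dependencies, so the
        # cache survives as long as they are unchanged.
        if hash_args:
            return get_jit_hash(hash_args)
        # compile_once: a constant key, so re-creating the decorated function
        # (e.g. inside a training loop) reuses the first compilation; stale
        # free variables are the caller's risk, as the docstring warns.
        if compile_once:
            return 0
        # Default: a fresh nanosecond timestamp, forcing recompilation per creation.
        return int(time.time() * 1e9)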
@@ -984,8 +895,8 @@ def _no_recursive(callable_obj):
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
     """
-
-    if not
+    is_cell_subclass = inspect.isclass(callable_obj) and issubclass(callable_obj, ms.nn.Cell)
+    if not is_cell_subclass and not inspect.ismethod(callable_obj) and not inspect.isfunction(callable_obj):
         raise TypeError(f"Decorator no_recursive is used for callable object, but got {callable_obj}.")
     _add_flags(callable_obj, no_recursive=True)
     return callable_obj
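The rewritten check widens no_recursive to accept Cell subclasses in addition to plain functions and methods. A quick illustration of which objects the new predicate admits, using hypothetical mini-types in place of ms.nn.Cell but the same three tests as the hunk:

    import inspect

    class Base:            # stand-in for ms.nn.Cell
        pass

    class Net(Base):       # a Cell subclass: now accepted
        pass

    def fn():              # a plain function: accepted as before
        pass

    for obj in (Net, fn, 42):
        is_subclass = inspect.isclass(obj) and issubclass(obj, Base)
        ok = is_subclass or inspect.ismethod(obj) or inspect.isfunction(obj)
        print(obj, "accepted" if ok else "rejected")   # 42 -> rejected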
@@ -1149,7 +1060,7 @@ def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
     _broadcast_net.phase = broadcast_phase
     broadcasted_params = _broadcast_net()
     for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
-        broadcast_params_dict
+        broadcast_params_dict.get(param_name).set_data(param)


 def _get_auto_split_param_names(parameter_layout_dict):
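Worth noting when reading this hunk: the loop iterates over broadcast_params_dict.keys(), so every param_name is present and dict.get can never return None here; the switch from [] indexing to .get changes style, not behaviour. A tiny demonstration of that equivalence under the same assumption:

    params = {"w": [0.0]}
    updates = ([1.0, 2.0],)
    for name, new in zip(params.keys(), updates):
        # Every name comes from params itself, so .get(name) behaves exactly
        # like params[name]; None is only possible for an absent key.
        params.get(name).extend(new)
    print(params)   # {'w': [0.0, 1.0, 2.0]}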
@@ -1355,30 +1266,18 @@ class _PyNativeExecutor:
         """
        self._executor.sync()

-    def
+    def grad_jit(self, output, *args):
         """
-
+        Building grad graph decorated by jit.

         Args:
-
+            output (tuple): The function or cell decorated by jit output object.
+            args (tuple): Function or cell decorated by jit input arguments.

         Return:
             None.
         """
-        self._executor.
-
-    def grad_ms_function(self, output, *args):
-        """
-        Building grad graph decorated by ms_function.
-
-        Args:
-            output (tuple): The function or cell decorated by ms_function output object.
-            args (tuple): Function or cell decorated by ms_function input arguments.
-
-        Return:
-            None.
-        """
-        return self._executor.grad_ms_function(output, *args)
+        return self._executor.grad_jit(output, *args)

     def grad_flag(self):
         """
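grad_ms_function is folded into grad_jit here, completing the ms_function-to-jit rename visible throughout this release. Callers pinned to the old name could keep a one-line adapter during migration (hypothetical; the 2.2.11 file itself simply drops the old spelling):

    class _PyNativeExecutorCompat:
        # Thin adapter over the renamed entry point; illustrative only.
        def __init__(self, executor):
            self._executor = executor

        def grad_ms_function(self, output, *args):
            # Delegate to the new name with an identical signature.
            return self._executor.grad_jit(output, *args)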
@@ -1422,29 +1321,42 @@ class _PyNativeExecutor:
         """
         self._executor.set_enable_grad(flag)

-    def
+    def set_jit_compile_status(self, status, phase):
         """
-        Set
+        Set jit is compiling

         Args:
-            status(bool):
+            status(bool): jit compile status
            phase (str): The phase of cell/function instance.
         Return:
             None.
         """
-        self._executor.
+        self._executor.set_jit_compile_status(status, phase)

-    def set_dynamic_input(self, obj):
+    def set_dynamic_input(self, obj, *args):
         """
         Set dynamic shape tensor of input arguments.

         Args:
             obj (Function/Cell): The function or cell instance.
+            args (tuple): Function or cell dynamic input arguments.

         Return:
             None.
         """
-        self._executor.set_dynamic_input(obj)
+        self._executor.set_dynamic_input(obj, *args)
+
+    def get_dynamic_input(self, *actual_args):
+        """
+        Get dynamic shape arguments according to actual input arguments.
+
+        Args:
+            actual_args(tuple): Actual input arguments of Function or Cell.
+
+        Return:
+            dynamic_shape_args(tuple): Dynamic shape arguments of Function or Cell.
+        """
+        return self._executor.get_dynamic_input(*actual_args)

     def is_first_cell(self):
         """
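These two methods form a pair: set_dynamic_input stores dynamic-shape specs under a function or Cell, and get_dynamic_input later maps concrete call arguments back to those specs, which is exactly the call _generate_compile_args now makes in the first hunk of this file. A usage sketch against a hypothetical executor honouring that contract:

    def resolve_compile_args(executor, net, dyn_spec, actual_args):
        # Register the dynamic-shape specs once, then let the executor map
        # concrete inputs back to them at compile time, so one compiled graph
        # can serve every compatible concrete shape (names are illustrative).
        executor.set_dynamic_input(net, *dyn_spec)
        return executor.get_dynamic_input(*actual_args)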
@@ -1550,6 +1462,13 @@ class _CellGraphExecutor:
         """
         self._graph_executor.set_queue_name(queue_name)

+    def get_queue_name(self, dataset_phase):
+        """
+        Get cached queue name for the graph loaded from compile cache.
+        :return: cached queue name
+        """
+        return self._graph_executor.get_queue_name(dataset_phase)
+
     @staticmethod
     def _set_dataset_mode(obj):
         """set dataset mode."""
@@ -1597,15 +1516,18 @@ class _CellGraphExecutor:
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
                 'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+        key_id = str(id(obj)) + str(obj.create_time)
+        args = get_auto_dynamic_shape_args(args, key_id)

         self.enable_tuple_broaden = False
         if hasattr(obj, "enable_tuple_broaden"):
             self.enable_tuple_broaden = obj.enable_tuple_broaden
-
+        logger.debug("Convert the network.", do_convert)
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
         obj.arguments_key = str(key)
         phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        update_auto_dynamic_shape_phase(args, key_id, phase)

         if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
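The two inserted calls wire this compile path into the new auto-dynamic-shape machinery: a cache key derived from object identity plus creation time is used first to (possibly) generalize the argument shapes, then to associate the resulting phase with them. A sketch of the key construction only (the generalization strategy inside get_auto_dynamic_shape_args is not shown in this diff):

    class FakeCell:
        # Minimal stand-in carrying the attribute the hunk relies on.
        create_time = 7

    def make_key_id(obj):
        # Identity plus creation time: two live instances can share neither
        # part, so their shape statistics never mix.
        return str(id(obj)) + str(obj.create_time)

    print(make_key_id(FakeCell()))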
@@ -1616,12 +1538,12 @@ class _CellGraphExecutor:
         self._set_dataset_mode(obj)
         self._set_compile_cache_dep_files(phase)

-        enable_ge = context.get_context("enable_ge")
-        if enable_ge:
-            obj.add_flags(ge_init=True)
         self._graph_executor.set_weights_values(obj.parameters_dict())
         if jit_config_dict:
             self._graph_executor.set_jit_config(jit_config_dict)
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+            self._graph_executor.set_jit_config(jit_config_dict)
         result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
         obj.compile_cache.add(phase)
         if not result:
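The new else-branch guarantees the graph executor always receives a jit config: the explicit one when given, otherwise the defaults of a freshly constructed JitConfig(). The fallback pattern in isolation (graph_executor here is a hypothetical consumer with the set_jit_config method shown above):

    from mindspore import JitConfig

    def apply_jit_config(graph_executor, jit_config_dict=None):
        # An absent config no longer leaves the executor unconfigured;
        # default JitConfig settings are pushed instead.
        if not jit_config_dict:
            jit_config_dict = JitConfig().jit_config_dict
        graph_executor.set_jit_config(jit_config_dict)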
@@ -1639,17 +1561,10 @@ class _CellGraphExecutor:
         obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
         obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)

-        if
-            return phase, True
-
-        # the following GE init process is not needed when use vm or ms backend
-        if enable_ge:
-            pass
-        elif "export" in phase:
+        if "export.air" in phase:
             self._build_data_graph(obj, phase)
         elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
             _parameter_broadcast(obj)
-
         return phase, True

     def _update_param_node_default_input(self, phase, replace):