mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release. This version of mindspore might be problematic; details are available on the registry page.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
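Among the mindspore/common additions above, generator.py introduces a per-instance random-number generator. A minimal sketch, assuming the torch-style Generator surface (manual_seed/get_state/set_state) that the top-level mindspore.Generator export appears to mirror:

```python
# Hypothetical usage of the new mindspore.Generator; the method names
# below are assumptions based on its torch-like design.
import mindspore as ms

gen = ms.Generator()
gen.manual_seed(42)       # seed this generator independently of the global seed
state = gen.get_state()   # snapshot the RNG state
# ... draw random numbers ...
gen.set_state(state)      # restore the snapshot for reproducible replay
```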
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
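The new communication/comm_func.py above adds functional collective operations. A sketch under a multi-process launch (e.g. msrun, consistent with the entry_points.txt change further down), assuming the functional convention of returning a new tensor rather than mutating the input:

```python
# Sketch of the functional collectives in mindspore.communication.comm_func;
# run this under a distributed launch, one process per rank.
import mindspore as ms
from mindspore.communication import init
from mindspore.communication.comm_func import all_reduce

init()                            # join the default communication group
x = ms.Tensor([1.0, 2.0, 3.0])
y = all_reduce(x)                 # sums x across all ranks by default
print(y)
```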
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
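The experimental/optim block above grows from a handful of optimizers to a full torch-style family (Adadelta, Adagrad, Adamax, ASGD, NAdam, RAdam, RMSprop, Rprop). A sketch pairing one new optimizer with the existing lr_scheduler module, assuming the torch-style constructor signatures this package mirrors:

```python
# Sketch: a newly added experimental optimizer plus a step-decay schedule.
import mindspore.nn as nn
from mindspore.experimental import optim

net = nn.Dense(4, 2)
optimizer = optim.NAdam(net.trainable_params(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# In a training loop: optimizer(grads) applies one update,
# then scheduler.step() decays the learning rate every 10 steps.
```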
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
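mindspore/hal is a new package for device, stream, event, and memory management. A minimal sketch, assuming the Stream/StreamCtx names suggested by the file list; stream semantics are backend-dependent (largely no-ops on CPU):

```python
# Sketch of stream-ordered execution with the new mindspore.hal package.
import mindspore as ms
from mindspore import hal

print(hal.device_count())        # devices visible for the current target
stream = hal.Stream()            # create an execution stream
with hal.StreamCtx(stream):      # launch subsequent kernels on that stream
    x = ms.ops.ones((1024, 1024))
    y = x @ x
stream.synchronize()             # block until the stream's work finishes
```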
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
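The mindrecord package gains a module-level config.py and a reworked filewriter.py. The schema/write/commit flow below is the long-standing public FileWriter API, shown here to anchor those changes:

```python
# Writing a minimal MindRecord file with the reworked FileWriter.
from mindspore.mindrecord import FileWriter

writer = FileWriter(file_name="sample.mindrecord", shard_num=1, overwrite=True)
writer.add_schema({"label": {"type": "int32"},
                   "data": {"type": "bytes"}}, "demo_schema")
writer.write_raw_data([{"label": 0, "data": b"\x01\x02"}])
writer.commit()   # finalize the file and its index
```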
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
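mindspore/mint is an entirely new namespace (over 20k added lines above) exposing a torch-like functional API, with nn, optim, distributed, linalg, and special submodules. A small sketch; exact operator coverage per backend is an assumption, and in this release line mint primarily targets the Ascend backend:

```python
# Sketch of the torch-style mindspore.mint API; note dim= rather than axis=.
from mindspore import mint

x = mint.arange(6).reshape(2, 3)
y = mint.ones((2, 3))
z = mint.add(x, y)                          # element-wise add
p = mint.nn.functional.softmax(z, dim=-1)   # torch-style keyword
```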
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
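The new mindspore/multiprocessing package wraps the standard library module so that child processes reinitialize the MindSpore runtime safely; treating it as a drop-in replacement is the apparent intent, though the fork-handling details are internal. A sketch under that assumption:

```python
# Sketch: stdlib-style Pool via mindspore.multiprocessing.
from mindspore import multiprocessing

def worker(n):
    import mindspore as ms
    return float(ms.ops.ones(n).sum())

if __name__ == "__main__":
    with multiprocessing.Pool(2) as pool:
        print(pool.map(worker, [2, 3]))
```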
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
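mindspore.numpy gains an fft submodule (fft.py, +966 lines above). Assuming it mirrors numpy.fft's call signatures, as the rest of mindspore.numpy does, a round-trip sketch:

```python
# Sketch of the new mindspore.numpy.fft submodule (assumed numpy-compatible).
import mindspore as ms
import mindspore.numpy as mnp

x = mnp.arange(8).astype(ms.float32)
freq = mnp.fft.fft(x)       # complex spectrum of the length-8 signal
back = mnp.fft.ifft(freq)   # recovers x up to floating-point error
```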
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
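The parallel package gains a cluster/ launcher (the msrun entry point matches the entry_points.txt change below), parameter_broadcast.py, and transform_safetensors.py for safetensors checkpoints. A sketch of the checkpoint side, assuming the format="safetensors" keyword on the top-level save/load helpers that this machinery plugs into:

```python
# Sketch: saving and reloading a cell as safetensors.
import mindspore as ms
import mindspore.nn as nn

net = nn.Dense(4, 2)
ms.save_checkpoint(net, "dense.safetensors", format="safetensors")
params = ms.load_checkpoint("dense.safetensors", format="safetensors")
ms.load_param_into_net(net, params)
```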
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
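The profiler stack is heavily reworked above: a new ascend_analysis parser package, a dynamic_profiler.py, and a largely rewritten profiling.py. The collect-then-analyse flow below is the long-standing public Profiler API; the new parsers run during analyse():

```python
# Profiling a workload with the reworked profiler stack.
import mindspore as ms
from mindspore import Profiler

ms.set_context(mode=ms.GRAPH_MODE)
profiler = Profiler(output_path="./profiler_data")
# ... run training or inference steps here ...
profiler.analyse()   # parse collected data into timeline and op reports
```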
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,12 +38,13 @@ from mindspore.common.tensor import Tensor as PythonTensor
 from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
+from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
     _ms_memory_recycle, _bind_device_ctx
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
-    _is_in_auto_parallel_mode
+    _is_in_auto_parallel_mode, _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
@@ -51,6 +52,8 @@ from mindspore.common.mutable import mutable
 from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
     get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
+from mindspore.common._pijit_context import PIJitCaptureContext
+from mindspore.common.parameter import Parameter

 # Store ms_function class compiled pipeline cache.
 ms_compile_cache = set()
@@ -62,6 +65,68 @@ function_phases = dict()
 BROADCAST_PHASE = "_broadcast_"
 _PYNATIVE_PARALLEL_FUNC_NAME = "after_shard"

+ARG_SPECIFIED = "arg_specified_infos"
+TOTAL_ARG_LEN = "total_arg_length"
+
+
+def _check_recompile_args(compile_args, kwargs):
+    """Check recompile of graph"""
+
+    def _check_constant_tensor_arg(arg):
+        if hasattr(arg, "__ms_mutable__"):
+            return False
+        if isinstance(arg, (list, tuple)):
+            return any(_check_constant_tensor_arg(x) for x in arg)
+        return isinstance(arg, Tensor)
+
+    for v in kwargs.values():
+        compile_args += (v,)
+    for arg in compile_args:
+        if not isinstance(arg, tuple) and not isinstance(arg, list):
+            continue
+        if _check_constant_tensor_arg(arg):
+            logger.warning(f"Constant value tensor are detected in tuple or list, which might cause recompiling "
+                           f"when tensor value changes. You can use mutable(Tensor) or mutable(tuple(Tensor)) "
+                           f"to set tensor's value as variable to to avoid recompiling. The tuple or list arg "
+                           f"is: {arg} .")
+            return
+
+
+def _check_recompile(obj, compile_args, kwargs, full_function_name, create_time, echo_function_name):
+    """Warning when the function has been compiled."""
+    ignore_dirs = ["mindspore/ops", "mindspore/nn"]
+    if any((lambda x: x in full_function_name)(x) for x in ignore_dirs):
+        return
+
+    if full_function_name in function_phases:
+        warning_times = 1
+        if len(function_phases[full_function_name]) >= warning_times \
+                and create_time not in function_phases[full_function_name]:
+            if isinstance(obj, ms.nn.Cell):
+                tips = f"Please try to create {echo_function_name} instance only once to avoid recompiling. "
+                logger.info(f"The {echo_function_name} has been compiled again. "
+                            f"{tips} ")
+            else:
+                tips = "Try to decorate the function with @jit(hash_args=...) " \
+                       "or @jit(compile_once=True) to reduce the compile time. " \
+                       "For more details, get instructions about `jit` at " \
+                       "https://www.mindspore.cn/search?inputValue=jit."
+                logger.warning(f"The {echo_function_name} has been compiled again. "
+                               f"{tips} ")
+        else:
+            _check_recompile_args(compile_args, kwargs)
+    else:
+        function_phases[full_function_name] = set()
+    function_phases[full_function_name].add(create_time)
+
+
+def _ms_adapter_tensor_as_parameter_output(data):
+    """Check whether the data is an output from a parameter which is a ms_adapter tensor.
+    Pylint: disable=unidiomatic-typecheck.
+    """
+    return ms_adapter_registry.is_registered and isinstance(data, ms_adapter_registry.tensor) \
+        and hasattr(data, "__ms_parameter_output__") and getattr(data, "__ms_parameter_output__")
+

 def _convert_python_data(data):
     """
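The `_check_recompile_args` helper added above only warns; avoiding the recompile is left to the caller. Below is a minimal sketch of the pattern the warning points at, with a made-up jitted function and shapes (not taken from the diff):

import numpy as np
from mindspore import Tensor, jit, mutable

@jit
def accumulate(pair):
    return pair[0] + pair[1]

x = Tensor(np.ones([2, 2]).astype(np.float32))
y = Tensor(np.ones([2, 2]).astype(np.float32))

# A plain tuple of Tensors is treated as a compile-time constant, so calling
# again with different tensor values may trigger a fresh compilation.
out1 = accumulate((x, y))

# Wrapping the tuple in mutable() marks the values as variables, so the
# compiled graph can be reused across calls with new values.
out2 = accumulate(mutable((x, y)))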
@@ -73,8 +138,10 @@ def _convert_python_data(data):
     Returns:
         data, a data convert C++ to python
     """
-    if isinstance(data, Tensor) and data.adapter_flag:
+    if isinstance(data, (Tensor, PythonTensor)) and data.adapter_flag:
         return ms_adapter_registry.tensor(data)
+    if _ms_adapter_tensor_as_parameter_output(data) and hasattr(data, "tensor"):
+        return data.tensor
     if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
         return PythonTensor(data, internal=True)
     if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
@@ -83,7 +150,7 @@
         return PythonCOOTensor(coo_tensor=data)
     if isinstance(data, RowTensor) and not isinstance(data, PythonRowTensor):
         return PythonRowTensor(row_tensor=data)
-    if
+    if data.__class__ is tuple:
         # Handle namedtuple since its type is tuple.
         if hasattr(data, "_fields"):
             type_name = data.__class__.__name__
@@ -91,12 +158,12 @@
             fields = data_dict.keys()
             return namedtuple(type_name, fields)(**_convert_python_data(data_dict))
         return tuple(_convert_python_data(x) for x in data)
-    if
+    if data.__class__ is list:
         # Keep list object not change for inplace operation.
         for i in range(len(data)):
             data[i] = _convert_python_data(data[i])
         return data
-    if
+    if data.__class__ is dict:
         # Keep the dict object not change.
         keys = tuple(data.keys())
         for key in keys:
@@ -167,6 +234,7 @@ def _handle_func_args(func, *args, **kwargs):

     sys_path = list(sys.path)
     # Get the entry script path.
+    entry_script_path = None
     if sys.argv and sys.argv[0] != '':
         entry_script_path = os.path.realpath(sys.argv[0])
         entry_script_path_dir = os.path.split(entry_script_path)[0]
@@ -260,7 +328,7 @@ def _get_parameter_layout():
     return layout


-def _handle_arg(obj, arg):
+def _handle_arg(obj, arg, compile_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     if isinstance(arg, PythonTensor):
         if arg.has_init:
@@ -269,7 +337,7 @@ def _handle_arg(obj, arg):
         return arg
     elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
         return arg
-    elif hasattr(
+    elif compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and getattr(compile_arg, "__ms_mutable__"):
         # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
         if isinstance(arg, list) and not arg:
             return None
@@ -282,22 +350,185 @@ def _handle_arg(obj, arg):
     return None


-def
+def _handle_arg_predict(obj, arg, compile_arg):
+    """Handle arg for runtime .If need handle the arg, return True"""
+    if arg is None:
+        return None
+
+    if isinstance(arg, (int, float)):
+        return None
+
+    if isinstance(arg, (list, tuple)):
+        if compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
+                getattr(compile_arg, "__ms_mutable__"):
+            # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
+            if isinstance(arg, list) and not arg:
+                return None
+            return arg
+        if hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
+                _check_all_tensor(arg):
+            return arg
+        return None
+    return arg
+
+
+def _get_args_for_run(obj, args, kwargs, compile_args):
     """Get the actual input args and kwargs for runtime."""
     new_args = []
-    for arg in args:
-        new_arg = _handle_arg(obj, arg)
+    for arg, compile_arg in zip(args, compile_args):
+        new_arg = _handle_arg(obj, arg, compile_arg)
         if new_arg is not None:
             new_args.append(new_arg)

     for _, value in kwargs.items():
-        new_value = _handle_arg(obj, value)
+        new_value = _handle_arg(obj, value, None)
         if new_value is not None:
             new_args.append(new_value)

     return new_args

+
+def _get_args_for_run_predict(obj, args, kwargs, compile_args):
+    """Get the actual input args and kwargs for runtime."""
+    new_args = []
+    for arg, compile_arg in zip(args, compile_args):
+        new_arg = _handle_arg_predict(obj, arg, compile_arg)
+        if new_arg is not None:
+            new_args.append(new_arg)
+
+    for _, value in kwargs.items():
+        new_value = _handle_arg_predict(obj, value, None)
+        if new_value is not None:
+            new_args.append(new_value)
+
+    return new_args
+
+
+def _is_args_fullmode(args, is_init=True):
+    """Check whether the arguments is for incremental-mode.
+
+    Args:
+        args (Union[list, tuple, dict, Tensor]): Given arguments.
+        is_init (bool): Is check in argument initialization phase.
+
+    Raises:
+        RuntimeError: loss necessary keys and values for incremental-mode.
+
+    Returns:
+        bool: Fullmode or not.
+    """
+    if not isinstance(args, dict):
+        return True
+    if not is_init and (args.get(ARG_SPECIFIED, None) is None or args.get(TOTAL_ARG_LEN, None) is None):
+        raise RuntimeError(
+            "The incremental inputs should be processed(with \"%s\" and \"%s\"), but got %s." %
+            (ARG_SPECIFIED, TOTAL_ARG_LEN, str(args)))
+    return False
+
+
+def _process_dyn_args(fn, dyn_args):
+    """Process the dynamic arguments, return the necessary data for latter processing.
+
+    Args:
+        fn (Function): The root function to compile.
+        dyn_args (Union[dict, list, tuple, None]): Given arguments for dynamic compilation.
+            None for nothing, list or tuple for fullmode setting, dict for incremental configuration.
+
+    Returns:
+        A dict which contains args for dynamic compilation. None for nothing dynamic.
+    """
+    if dyn_args is None:
+        # nothing should be done for None.
+        return dyn_args
+
+    if isinstance(dyn_args, dict) and ARG_SPECIFIED in dyn_args:
+        return dyn_args
+
+    args_sig = inspect.signature(fn)
+    if _is_args_fullmode(dyn_args):
+        if not isinstance(dyn_args, (list, tuple)):
+            temp_dyn_args = (dyn_args,)
+        else:
+            temp_dyn_args = dyn_args
+
+        # If dyn_args is fullmode, it should be apply directly.
+        args_sig_parameters = list(args_sig.parameters.values())
+        if not args_sig_parameters:
+            return ()
+
+        # fn may be Cell's construct while the first input is 'self'.
+        if args_sig_parameters[0].name == "self" and (len(temp_dyn_args) + 1) == len(args_sig_parameters):
+            bound_args = args_sig.bind(None, *temp_dyn_args)
+            bound_args.apply_defaults()
+            return bound_args.args[1:]
+
+        bound_args = args_sig.bind(*temp_dyn_args)
+        bound_args.apply_defaults()
+        return bound_args.args
+
+    # The dyn_args is not fullmode, a real compilation arguments should be assembled by latter procession...
+    arg_names = []
+    args_sig_parameters = list(args_sig.parameters.values())
+    for arg_p in args_sig_parameters:
+        if arg_p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
+            arg_names.append(arg_p.name)
+        else:
+            raise TypeError("Dynamic arguments is not accepted for VAR_POSITIONAL or VAR_KEYWORD parameters!")
+
+    offset = -1 if fn.__name__ == 'construct' and args_sig_parameters[0].name == "self" else 0
+    meet_index = set()
+
+    def _check_index_valid(index):
+        if index >= len(arg_names):
+            raise ValueError("For dict mode, valid index is \"0\"-\"%d\", but got %s!" % (len(arg_names) - 1, index))
+        if index in meet_index:
+            raise ValueError("For dict mode, there are more than one same specified key for real index: %d!" % index)
+        meet_index.add(index)
+
+    arg_handler_infos = []
+    for k, v in dyn_args.items():
+        if not isinstance(k, str):
+            raise TypeError("For dict mode, only string key is accepted, but got %s!" % k)
+        if k in arg_names:
+            cur_id = arg_names.index(k)
+            _check_index_valid(cur_id)
+            arg_handler_infos.append([cur_id + offset, v])
+        else:
+            raise ValueError("For dict mode, valid key is %s, but got %s!" % (arg_names, k))
+    return {ARG_SPECIFIED: arg_handler_infos, TOTAL_ARG_LEN: len(args_sig_parameters)}
+
+
+def _generate_dyn_compile_args(compile_args, dyn_args):
+    """Generate the dynamic compile arguments."""
+    if not dyn_args:
+        return compile_args
+    if _is_args_fullmode(dyn_args, False):
+        if not isinstance(dyn_args, (list, tuple)):
+            return (dyn_args,)
+        return dyn_args
+    arg_specified_infos = dyn_args.get(ARG_SPECIFIED, None)
+    if arg_specified_infos is None:
+        raise RuntimeError("For dict mode, a key with \"%s\" should exist, but got %s!" %
+                           (ARG_SPECIFIED, str(dyn_args)))
+    new_compile_args = list(compile_args)
+    for index, arg in arg_specified_infos:
+        new_compile_args[index] = arg
+    return tuple(new_compile_args)
+
+
+def _get_parameter_ids(args, kwargs):
+    """Get the ids of parameters."""
+    parameter_ids = ""
+    for arg in args:
+        if isinstance(arg, Parameter):
+            parameter_ids += str(id(arg))
+    for _, value in kwargs.items():
+        # The type of key is usually String type.
+        if isinstance(value, Parameter):
+            parameter_ids += str(id(value))
+    return parameter_ids
+
+
 class _MindsporeFunctionExecutor:
     """
     Represents a function compiled by graph compiler.
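The incremental `input_signature` handling in `_process_dyn_args` above leans on `inspect.signature` binding to fill in defaults and to skip a leading `self` for bound `construct` methods. A stdlib-only sketch of that binding trick, with illustrative names:

import inspect

def construct(self, x, y=None):
    return x, y

sig = inspect.signature(construct)
# Pass a placeholder for `self`, bind the dynamic arg, fill defaults,
# then drop the placeholder again.
bound = sig.bind(None, "dyn_x")
bound.apply_defaults()
print(bound.args[1:])  # ('dyn_x', None)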
@@ -315,6 +546,7 @@ class _MindsporeFunctionExecutor:
     Returns:
         The result of pipeline running in graph mode.
     """
+
     def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
@@ -329,9 +561,9 @@ class _MindsporeFunctionExecutor:
         self.enable_tuple_broaden = False
         self._graph_executor = GraphExecutor_.get_instance()
         self._create_time = ms_create_time
+        self._compile_args = None
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None

-
     @_wrap_func
     def __call__(self, *args, **kwargs):
         args_list = args
@@ -359,7 +591,6 @@ class _MindsporeFunctionExecutor:

         return output

-
     def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
         # Check whether hook function registered on Cell object.
@@ -376,6 +607,7 @@ class _MindsporeFunctionExecutor:

         # Restore the mutable attr for every arg.
         compile_args = _restore_mutable_attr(args, compile_args)
+        self._compile_args = compile_args
         generate_name, echo_function_name = self._get_generate_name()
         # The full Function name
         full_function_name = generate_name
@@ -409,14 +641,21 @@ class _MindsporeFunctionExecutor:

         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            key = str(key) + '.' + parameter_ids
         phase = generate_name + '.' + str(key)

         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)

         if phase in ms_compile_cache:
+            # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
+            # generated in generate_arguments_key.
+            self._graph_executor.clear_compile_arguments_resource()
             return phase

-        self._check_recompile(full_function_name, create_time, echo_function_name)
+        _check_recompile(self.obj, compile_args, kwargs, full_function_name, create_time, echo_function_name)

         # If enable compile cache, get the dependency files list and set to graph executor.
         self._set_compile_cache_dep_files()
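The `_get_parameter_ids` call folded into the arguments key above exists because two Parameter objects with the same shape and dtype would otherwise produce the same key even though the graph must capture a different object. A toy illustration of the id-suffixed key (the base key value is made up, stdlib only):

class FakeParam:
    pass

p1, p2 = FakeParam(), FakeParam()
base_key = "12345"  # stand-in for generate_arguments_key output
key1 = base_key + '.' + str(id(p1))
key2 = base_key + '.' + str(id(p2))
assert key1 != key2  # distinct Parameter objects -> distinct compile phases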
@@ -448,29 +687,6 @@ class _MindsporeFunctionExecutor:

         return phase

-    def _check_recompile(self, full_function_name, create_time, echo_function_name):
-        """Warning when the function has been compiled."""
-        ignore_dirs = ["mindspore/ops", "mindspore/nn"]
-        if any((lambda x: x in full_function_name)(x) for x in ignore_dirs):
-            return
-
-        if full_function_name in function_phases:
-            warning_times = 1
-            if len(function_phases[full_function_name]) >= warning_times \
-                    and create_time not in function_phases[full_function_name]:
-                tips = "Try to decorate the function with @jit(hash_args=...) " \
-                       "or @jit(compile_once=True) to reduce the compile time. " \
-                       "For more details, get instructions about `jit` at " \
-                       "https://www.mindspore.cn/search?inputValue=jit."
-
-                logger.warning(f"The {echo_function_name} has been compiled again. "
-                               f"{tips} ")
-        else:
-            function_phases[full_function_name] = set()
-
-        function_phases[full_function_name].add(create_time)
-
-
     @staticmethod
     def _optimizer_state_init(opt_states):
         """set data for all optimizer states in case it is executed in graph mode"""
@@ -481,7 +697,6 @@ class _MindsporeFunctionExecutor:
             if opt_param.has_init and (prefix in prefix_list or opt_param.name == "global_step"):
                 opt_param.init_data()

-
     def _get_key_id(self):
         """get key id."""
         if isinstance(self.obj, ms.nn.Cell):
@@ -493,7 +708,6 @@ class _MindsporeFunctionExecutor:
             key_id = key_id + ".grad"
         return key_id

-
     def _get_generate_name(self):
         """get generate name."""
         generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + str(
@@ -506,54 +720,47 @@ class _MindsporeFunctionExecutor:
             generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
         return generate_name, echo_function_name

-
     def _set_compile_cache_dep_files(self):
         # If enable compile cache, get the dependency files list
         enable_compile_cache = context.get_context("enable_compile_cache")
-        if enable_compile_cache is
+        if enable_compile_cache is None:
             enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
         if enable_compile_cache is True or enable_compile_cache == "1":
             self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

-
     def _generate_compile_args(self, args_list):
         """Chose dynamic shape tensors or actual input tensors as compile args."""
         # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
         compile_args = _pynative_executor.get_dynamic_input(args_list)
         # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
         if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
-            compile_args = self.obj.get_inputs()
+            compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
             if len(compile_args) != len(args_list):
                 raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
                                  f"dynamic shape tensors: {len(compile_args)}.")
-
-
-            Validator.check_dynamic_shape(compile_args[i], args_list[i], i)
+            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
+            Validator.check_symbolic_shape(compile_args, args_list)

         # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
         if self.input_signature is not None:
-
-
-
-            dyn_shape = False
-            for i, elem in enumerate(self.input_signature):
-                if isinstance(elem, PythonTensor) and is_shape_unknown(elem.shape):
-                    Validator.check_dynamic_shape(self.input_signature[i], args_list[i], i)
-                    dyn_shape = True
+            compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
+            dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
+            Validator.check_symbolic_shape(self.input_signature, args_list)
             if dyn_shape:
                 # Checkout whether the `sens` has been added to args_list.
-                if len(
+                if len(compile_args) == len(args_list) - 1:
                     logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
-                                   f"of input_signature args '{len(
+                                   f"of input_signature args '{len(compile_args)}'. The last actual args may "
                                    f"be 'sens' and added it to compile args.")
-
-                    compile_args = tuple(
+                    compile_args.append(args_list[-1])
+                compile_args = tuple(compile_args)
+                self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
                 if self.obj is not None:
                     _pynative_executor.set_dynamic_input(self.obj, *compile_args)
                 else:
                     _pynative_executor.set_dynamic_input(self.fn, *compile_args)
             else:
-                if not verify_inputs_signature(
+                if not verify_inputs_signature(compile_args, args_list):
                     raise ValueError("The input args is incompatible with the args in `input_signature`!")
         return compile_args
@@ -568,7 +775,7 @@ class _MindsporeFunctionExecutor:
         Returns:
             new_inputs, new input args, which are required for running.
         """
-        return _get_args_for_run(self, args_list, kwargs)
+        return _get_args_for_run(self, args_list, kwargs, self._compile_args)


 # The attributes used to identify a given object.
@@ -596,29 +803,49 @@ def _get_jit_hash(hash_input):
     return _get_obj_id(hash_input)


-def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, compile_once=False):
+def jit(fn=None, mode="PSJit", input_signature=None, hash_args=None, jit_config=None, compile_once=False):
     """
     Create a callable MindSpore graph from a Python function.

     This allows the MindSpore runtime to apply optimizations based on graph.

     Note:
-        If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-
+        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
+          will not accept `**kwargs`.
+        - It is not supported to run a function with decoration @jit(mode="PIJit")
+          in static graph mode, in which case the decoration @jit(mode="PIJit") is considered invalid.
+        - Calls to functions with decorated @jit(mode="PIJit") inside functions
+          decorated with @jit(mode="PIJit") are not supported,
+          and the decoration @jit(mode="PIJit") is considered invalid.

     Args:
         fn (Function): The Python function that will be run as a graph. Default: ``None`` .
-
-
-
-
+        mode (str): The type of jit used, the value of mode should be ``PIJit`` or ``PSJit``. Default: ``PSJit`` .
+
+            - `PSJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_ :
+              Parse python ast to build graph.
+            - `PIJit <https://www.mindspore.cn/docs/en/master/model_train/program_form/pynative.html#pijit>`_ :
+              Parse python bytecode to build graph at runtime.
+
+        input_signature (Union[Tuple, List, Dict, Tensor]): The Tensor which describes the input arguments. The
+            shape and dtype of the Tensor will be supplied to this function. If `input_signature` is specified, the
+            input parameters of `fn` cannot accept `**kwargs`, and the shape and dtype of actual inputs should keep the
+            same as `input_signature`. Otherwise, TypeError will be raised. There are two mode for `input_signature`:
+
+            - Full mode: Arguments is a Tuple, List or a Tensor, and they will be used as all compile inputs
+              for graph-compiling.
+            - Incremental mode: Argument is a Dict, and they will set to some of the graph inputs, which will be
+              substituted into the input at the corresponding position for graph-compiling.
+
+            Default: ``None`` .
+
         hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
             like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
             will trigger recompilation. Default: ``None`` .
         jit_config (JitConfig): Jit config for compile. Default: ``None`` .
         compile_once(bool): ``True``: The function would be compiled once when it was created many times.
             But it may be wrong if the free variables were changed. ``False`` : It would be recompiled when
-            it was created again
+            it was created again.
             Default: ``False`` .

     Returns:
@@ -663,6 +890,13 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, compile_
         ...
         >>> out = tensor_add_with_sig(x, y)
         ...
+        >>> @jit(input_signature={"y": Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))})
+        ... def tensor_add_with_sig_1(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out1 = tensor_add_with_sig_1(x, y)
+        ...
         ... # Set hash_args as fn, otherwise cache of compiled closure_fn will not be reused.
         ... # While fn differs during calling again, recompilation will be triggered.
         >>> def func(x):
@@ -702,6 +936,8 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, compile_
         else:
             hash_obj = int(time.time() * 1e9)

+        dyn_args = _process_dyn_args(func, input_signature)
+
         @wraps(func)
         def staging_specialize(*args, **kwargs):
             if os.getenv("MS_JIT") == '0':
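Note that `dyn_args = _process_dyn_args(func, input_signature)` above runs once at decoration time and is then captured by the `staging_specialize` closure, so per-call overhead stays flat. A small stdlib sketch of that capture pattern, with hypothetical names standing in for the real helpers:

from functools import wraps

def my_jit(input_signature=None):
    def wrap(func):
        processed = ("processed", input_signature)  # stands in for _process_dyn_args
        @wraps(func)
        def staging(*args, **kwargs):
            # `processed` is the same object on every call; no per-call rework.
            return func(*args, **kwargs), processed
        return staging
    return wrap

@my_jit(input_signature={"y": "dyn"})
def f(x, y):
    return x + y

print(f(1, 2))  # (3, ('processed', {'y': 'dyn'}))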
@@ -715,14 +951,24 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, compile_
             # only the function or cell instance wrapped by shard will fall into this branch
             if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
                 process_obj = hash_args
-
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                if isinstance(func, types.MethodType):
+                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+                else:
+                    setattr(func, "amp_strategy", get_curr_amp_strategy())
             out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
             return out

         return staging_specialize

+    wrap_func = wrap_mindspore
+    if mode == "PIJit":
+        wrap_func = PIJitCaptureContext(jit_config, input_signature)
+
     if fn is not None:
-        return
-    return
+        return wrap_func(fn)
+    return wrap_func


 def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
@@ -732,15 +978,14 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
     This allows the MindSpore runtime to apply optimizations based on graph.

     Note:
-        `ms_function` will be deprecated and removed in a future version. Please use
-        If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
-
+        - `ms_function` will be deprecated and removed in a future version. Please use :func:`mindspore.jit` instead.
+        - If `input_signature` is specified, each input of `fn` must be a Tensor. And the input arguments for `fn`
+          will not accept `**kwargs`.

     Args:
         fn (Function): The Python function that will be run as a graph. Default: ``None`` .
         input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
-            will be supplied to this function.
-            And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
+            will be supplied to this function. The shape and dtype of actual inputs of `fn` should
             keep the same as input_signature. Otherwise, TypeError will be raised. Default: ``None`` .
         hash_args (Union[Object, List or Tuple of Objects]): The local free variables used inside `fn`,
             like functions or objects of class defined outside `fn`. Calling `fn` again with change of `hash_args`
@@ -909,7 +1154,7 @@ def ms_class(cls):
     This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.

     Note:
-        `ms_class` will be deprecated and removed in a future version. Please use
+        `ms_class` will be deprecated and removed in a future version. Please use :func:`mindspore.jit_class` instead.

     Args:
         cls (Class): User-defined class.
@@ -1015,7 +1260,7 @@ def jit_class(cls):
     if not inspect.isclass(cls):
         raise TypeError(f'Decorator jit_class can only be used for class type, but got {cls}.')
     # Check if cls is nn.Cell.
-    if issubclass(cls, nn.Cell):
+    if issubclass(cls, nn.cell.Cell):
         raise TypeError(f"Decorator jit_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
     setattr(cls, '__ms_class__', True)
     return cls
@@ -1037,6 +1282,8 @@ def set_adapter_config(config):
             ms_adapter_registry.register_parameter(value)
         elif key == "convert_object_map":
             ms_adapter_registry.register_convert_map(value)
+        elif key == "convert_adapter_tensor_map":
+            ms_adapter_registry.register_convert_adapter_tensor_map(value)
         else:
             raise ValueError(f"Unsupported key in adapter config: {key}")

@@ -1135,16 +1382,6 @@ class _PyNativeExecutor:
         self._executor = PyNativeExecutor_.get_instance()
         self._executor.set_py_exe_path(sys.executable)
         self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
-        self._top_cell = None
-
-    def __call__(self):
-        """
-        PyNative executor run grad graph.
-
-        Return:
-            The return object after running grad graph.
-        """
-        return self._executor()

     @staticmethod
     def parameter_broadcast(obj, phase):
@@ -1214,23 +1451,22 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))

-    def check_run(self, grad, obj, weights, grad_hash_id, *args
+    def check_run(self, grad, obj, weights, grad_hash_id, *args):
         """
         Whether the forward graph need to construct.

         Args:
             grad (GradOperation): The gradoperation object.
             obj (Function/Cell): The function or cell instance.
-            grad_hash_id (tuple): The id of objects which
+            grad_hash_id (tuple): The id of objects, which contributes to cache of compiled graph in pynative mode.
             args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.

         Return:
-            bool, specifies whether the forward graph
+            bool, specifies whether the forward graph needs to construct.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)

-    def grad(self, obj, grad, weights, grad_position, *args
+    def grad(self, obj, grad, weights, grad_position, *args):
         """
         Get grad graph.

@@ -1241,12 +1477,11 @@ class _PyNativeExecutor:
             grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
                 If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
             args (tuple): Function or cell input arguments.
-            kwargs (dict): keyword arguments.

         Return:
             None.
         """
-        self._executor.
+        return self._executor.grad(grad, obj, weights, grad_position, *args)

     def clear_res(self):
         """
@@ -1279,9 +1514,23 @@ class _PyNativeExecutor:
         """
         return self._executor.grad_jit(output, *args)

+    def call_custom_bprop(self, obj, output, *args, **kwargs):
+        """
+        Call custom bprop to build variable for cell bprop.
+        Args:
+            obj (Cell): The function or cell instance.
+            output (Tensor/tuple/list): Function or cell output object.
+            args (tuple): Function or cell input arguments.
+            kwargs (dict): keyword arguments.
+
+        Return:
+            None.
+        """
+        return self._executor.call_custom_bprop(obj, output, *args, *(kwargs.values()))
+
     def grad_flag(self):
         """
-        The flag of building grad graph.
+        The flag of whether the net building grad graph.

         Return:
             bool, whether building grad graph.
@@ -1300,9 +1549,21 @@ class _PyNativeExecutor:
         """
         self._executor.set_grad_flag(flag)

+    def set_async_for_graph(self, flag):
+        """
+        Set the flag for graph async run.
+
+        Args:
+            flag (bool): Specifying whether enable graph async run.
+
+        Return:
+            None.
+        """
+        self._executor.set_async_for_graph(flag)
+
     def enable_grad(self):
         """
-        The global flag whether
+        The global flag that whether need to calculate gradient use in no_grad.

         Return:
             bool, whether needing to calculate gradient.
@@ -1321,6 +1582,18 @@ class _PyNativeExecutor:
         """
         self._executor.set_enable_grad(flag)

+    def requires_grad(self):
+        """
+        When both enable_grad is true and grad_flag is true, that the flag requires_grad will be true.
+
+        Args:
+            flag (bool): Specifying whether calculating gradient.
+
+        Return:
+            None.
+        """
+        return self._executor.requires_grad()
+
     def set_jit_compile_status(self, status, phase):
         """
         Set jit is compiling
@@ -1333,6 +1606,29 @@ class _PyNativeExecutor:
         """
         self._executor.set_jit_compile_status(status, phase)

+    def set_is_run_recompute(self, status):
+        """
+        Set recompute grad is compiling
+
+        Args:
+            status(bool): grad is in recompute status
+        Return:
+            None.
+        """
+        self._executor.set_is_run_recompute(status)
+
+    def set_cell_use_dynamic_shape_process(self, flag):
+        """
+        Set the dynamic shape flag of eval process.
+
+        Args:
+            flag (bool): Specifying whether using a dynamic process.
+
+        Return:
+            None.
+        """
+        self._executor.set_cell_use_dynamic_shape_process(flag)
+
     def set_dynamic_input(self, obj, *args):
         """
         Set dynamic shape tensor of input arguments.
@@ -1358,36 +1654,19 @@ class _PyNativeExecutor:
         """
         return self._executor.get_dynamic_input(*actual_args)

-    def
-        """
-        The flag of first cell instance.
-
-        Return:
-            bool, specifies whether is the first cell.
-        """
-
-        return self._executor.is_first_cell()
-
-    def set_hook_changed(self, cell):
+    def set_mixed_precision_type(self, mixed_precision_type, is_push=True):
         """
-        The
+        The value of mixed precision type.

         Args:
-
+            type(MixedPrecisionType): Mix precision type.
+            is_push(bool): If called by __enter__, is push will be True

         Return:
             None.
         """
-        self._executor.set_hook_changed(cell)
-
-    def get_top_cell(self):
-        """
-        Get the top cell object.

-
-        The top cell object.
-        """
-        return self._top_cell
+        return self._executor.set_mixed_precision_type(mixed_precision_type, is_push)

     def constant_folding(self, *args):
         """
@@ -1424,6 +1703,7 @@ class _CellGraphExecutor:
         self._graph_executor = GraphExecutor_.get_instance()
         self._graph_executor.set_py_exe_path(sys.executable)
         self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
+        self._pid = os.getpid()

     def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                      input_indexs, phase='dataset', need_run=True):
@@ -1491,9 +1771,9 @@ class _CellGraphExecutor:
     def _set_compile_cache_dep_files(self, phase):
         # If enable compile cache, get the dependency files list
         enable_compile_cache = context.get_context("enable_compile_cache")
-        if enable_compile_cache is
+        if enable_compile_cache is None:
             enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
-        if
+        if enable_compile_cache is True or enable_compile_cache == "1":
             self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

     def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs):
@@ -1522,17 +1802,30 @@ class _CellGraphExecutor:
         self.enable_tuple_broaden = False
         if hasattr(obj, "enable_tuple_broaden"):
             self.enable_tuple_broaden = obj.enable_tuple_broaden
-        logger.debug("Convert the network."
+        logger.debug(f"Convert the network: {do_convert}.")
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
         obj.arguments_key = str(key)
+        # When exist parameter in the top graph inputs, need check if the parameter object has changed.
+        parameter_ids = _get_parameter_ids(args, kwargs)
+        if parameter_ids != "":
+            obj.arguments_key = obj.arguments_key + '.' + parameter_ids
+        raw_phase = phase
         phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
-
+        obj.current_phase = phase
         if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
+            # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
+            # generated in generate_arguments_key.
+            self._graph_executor.clear_compile_arguments_resource()
             return phase, False

+        full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
+        echo_function_name = obj.__class__.__name__
+        _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
+
         obj.check_names()
         _check_full_batch()
         self._set_dataset_mode(obj)
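The phase string assembled in `compile` above is plain '.'-joined bookkeeping, and the new `phase_cache` dict maps the raw phase name back to the fully qualified one. A tiny illustration with made-up values (only the layout matters):

create_time, obj_id, arguments_key = 17, 140230041, "98765"
raw_phase = "train"
phase = raw_phase + '.' + str(create_time) + '.' + str(obj_id) + '.' + str(arguments_key)
print(phase)  # train.17.140230041.98765
# phase_cache lets callers recover the qualified phase from the raw name.
phase_cache = {raw_phase: phase}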
@@ -1553,14 +1846,13 @@ class _CellGraphExecutor:
         if graph is None:
             raise RuntimeError("Compile graph failed for phase {}.".format(phase))

-        auto_parallel_mode = _is_in_auto_parallel_mode()
+        auto_parallel_mode = _is_in_auto_parallel_mode() or _is_parallel_mode()
         if not auto_parallel_mode:
             replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
             self._update_param_node_default_input(phase, replace)
         elif 'skip_auto_parallel_compile' not in obj.get_flags().keys():
             obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
             obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
-
         if "export.air" in phase:
             self._build_data_graph(obj, phase)
         elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
@@ -1600,6 +1892,18 @@ class _CellGraphExecutor:
         """
         return self._graph_executor.has_compiled(phase)

+    def flops_collection(self, phase='train'):
+        """
+        Specify whether have been compiled.
+
+        Args:
+            phase (str): The phase name. Default: 'predict'.
+
+        Returns:
+            bool, specifies whether the specific graph has been compiled.
+        """
+        return self._graph_executor.flops_collection(phase)
+
     @_wrap_func
     def _exec_pip(self, obj, *args, phase=''):
         """Execute the generated pipeline."""
@@ -1630,7 +1934,9 @@ class _CellGraphExecutor:

     def del_net_res(self, obj, net_id):
         """Clear the memory resource of a network."""
-
+        # no need to del net res by gc in independent dataset process which is a subprocess forked by main process
+        if self._pid == os.getpid():
+            self._graph_executor.del_net_res(obj, net_id)

     def _get_branch_control_input(self):
         if ('obf_ratio' not in self.obfuscate_config.keys()) or (
@@ -1738,7 +2044,21 @@ def _bind_device_context():
     _bind_device_ctx()


+def flops_collection(phase='train'):
+    """
+    Recycle memory used by MindSpore.
+    When train multi Neural network models in one process, memory used by MindSpore is very large,
+    this is because MindSpore cached runtime memory for every model.
+    To recycle these cached memory, users can call this function after training of one model.
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.ms_memory_recycle()
+    """
+    return _cell_graph_executor.flops_collection(phase)
+
+
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()

-__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class']
+__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class', 'jit', 'jit_class', 'flops_collection']
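Note that the docstring shipped with the new module-level `flops_collection` above is copy-pasted from `ms_memory_recycle` (as is the `_CellGraphExecutor.flops_collection` docstring, taken from `has_compiled`), so the snippet below is inferred from the signature alone and assumes the function is re-exported at the package top level like the other names in `__all__`; treat it as an assumption, not documented API:

import mindspore as ms

# Hypothetical usage: after compiling and running a graph for some phase,
# ask the graph executor for the FLOPs information it collected.
flops_info = ms.flops_collection(phase='train')
print(flops_info)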