mindspore 2.2.14__cp38-cp38-manylinux1_x86_64.whl → 2.3.0rc2__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -4
- mindspore/_akg/akg/composite/build_module.py +155 -11
- mindspore/_akg/akg/config/repository.json +38 -0
- mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
- mindspore/_akg/akg/utils/gen_random.py +14 -8
- mindspore/_akg/akg/utils/op_dsl.py +11 -0
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +18 -8
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +78 -0
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +229 -0
- mindspore/_extends/parse/parser.py +155 -59
- mindspore/_extends/parse/resources.py +40 -7
- mindspore/_extends/parse/standard_method.py +127 -206
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _profiler.py} +13 -16
- mindspore/amp.py +24 -18
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +7 -3
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +145 -50
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +9 -6
- mindspore/common/dump.py +5 -4
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +33 -13
- mindspore/common/lazy_inline.py +58 -17
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/parameter.py +24 -4
- mindspore/common/recompute.py +247 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +251 -18
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +391 -465
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/management.py +53 -38
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +176 -55
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +72 -38
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +7 -7
- mindspore/dataset/engine/datasets_vision.py +75 -71
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/transforms.py +2264 -514
- mindspore/dataset/vision/utils.py +40 -9
- mindspore/dataset/vision/validators.py +7 -1
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +34 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +1 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/mindapi/base/format.h +125 -23
- mindspore/include/mindapi/base/types.h +12 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6318 -1760
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +74 -56
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +131 -174
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/activation.py +79 -90
- mindspore/nn/layer/basic.py +113 -81
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +71 -71
- mindspore/nn/layer/embedding.py +105 -44
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +13 -9
- mindspore/nn/layer/rnn_cells.py +5 -15
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +43 -64
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +6 -5
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +33 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +7 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +54 -60
- mindspore/numpy/utils.py +3 -0
- mindspore/ops/__init__.py +5 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +14 -18
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +137 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +101 -57
- mindspore/ops/_vmap/vmap_nn_ops.py +230 -97
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +205 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +257 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +171 -0
- mindspore/ops/auto_generate/gen_extend_func.py +404 -0
- mindspore/ops/auto_generate/gen_ops_def.py +5653 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +11623 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +359 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +118 -17
- mindspore/ops/composite/math_ops.py +9 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +168 -602
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +24 -133
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +54 -0
- mindspore/ops/extend/array_func.py +259 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +384 -0
- mindspore/ops/function/__init__.py +37 -12
- mindspore/ops/function/array_func.py +702 -1867
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +24 -17
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +639 -2531
- mindspore/ops/function/nn_func.py +1274 -832
- mindspore/ops/function/other_func.py +4 -5
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +84 -71
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +57 -63
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +19 -20
- mindspore/ops/operations/_grad_ops.py +20 -828
- mindspore/ops/operations/_inner_ops.py +180 -288
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/array_ops.py +83 -2697
- mindspore/ops/operations/comm_ops.py +38 -46
- mindspore/ops/operations/custom_ops.py +14 -96
- mindspore/ops/operations/debug_ops.py +100 -31
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +3 -38
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +1716 -0
- mindspore/ops/operations/math_ops.py +581 -4629
- mindspore/ops/operations/nn_ops.py +260 -1941
- mindspore/ops/operations/other_ops.py +50 -42
- mindspore/ops/operations/random_ops.py +3 -52
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +196 -96
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +257 -0
- mindspore/ops_generate/arg_handler.py +171 -0
- mindspore/ops_generate/gen_aclnn_implement.py +266 -0
- mindspore/ops_generate/gen_ops.py +1062 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +939 -0
- mindspore/ops_generate/gen_utils.py +188 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +349 -0
- mindspore/ops_generate/template.py +238 -0
- mindspore/parallel/__init__.py +6 -4
- mindspore/parallel/_auto_parallel_context.py +52 -2
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +29 -13
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +19 -7
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +147 -6
- mindspore/parallel/algo_parameter_config.py +6 -6
- mindspore/parallel/checkpoint_transform.py +180 -24
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +345 -0
- mindspore/parallel/cluster/process_entity/_utils.py +116 -0
- mindspore/parallel/cluster/run.py +139 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +99 -2
- mindspore/profiler/common/util.py +20 -0
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +109 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
- mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +294 -133
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +305 -167
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +121 -35
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +635 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +107 -144
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/scipy/__init__.py +2 -1
- mindspore/scipy/fft.py +133 -0
- mindspore/scipy/linalg.py +140 -55
- mindspore/scipy/ops.py +15 -71
- mindspore/scipy/ops_grad.py +5 -34
- mindspore/scipy/optimize/line_search.py +2 -2
- mindspore/scipy/optimize/minimize.py +1 -1
- mindspore/train/__init__.py +3 -2
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +4 -4
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +39 -13
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +14 -8
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/data_sink.py +1 -1
- mindspore/train/dataset_helper.py +18 -4
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +24 -22
- mindspore/train/serialization.py +257 -133
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +534 -1066
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/config/super_bar_config.json +0 -544
- mindspore/gen_ops.py +0 -273
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/nn/__init__.py
CHANGED
|
@@ -21,6 +21,7 @@ from __future__ import absolute_import
|
|
|
21
21
|
|
|
22
22
|
from mindspore.nn import layer, loss, optim, wrap, grad, metrics, probability, sparse, dynamic_lr, reinforcement
|
|
23
23
|
from mindspore.nn.learning_rate_schedule import *
|
|
24
|
+
from mindspore.nn.generator import *
|
|
24
25
|
from mindspore.nn.dynamic_lr import *
|
|
25
26
|
from mindspore.nn.cell import Cell, GraphCell
|
|
26
27
|
from mindspore.nn.layer import *
|
|
@@ -31,6 +32,7 @@ from mindspore.nn.wrap import *
|
|
|
31
32
|
from mindspore.nn.grad import Jvp, Vjp
|
|
32
33
|
from mindspore.nn.sparse import *
|
|
33
34
|
from mindspore.nn.reinforcement import *
|
|
35
|
+
from mindspore.nn import extend
|
|
34
36
|
|
|
35
37
|
__all__ = ["Cell", "GraphCell"]
|
|
36
38
|
__all__.extend(layer.__all__)
|
|
@@ -43,5 +45,6 @@ __all__.extend(sparse.__all__)
|
|
|
43
45
|
__all__.extend(learning_rate_schedule.__all__)
|
|
44
46
|
__all__.extend(dynamic_lr.__all__)
|
|
45
47
|
__all__.extend(reinforcement.__all__)
|
|
48
|
+
__all__.extend(generator.__all__)
|
|
46
49
|
|
|
47
50
|
__all__.sort()
|
mindspore/nn/cell.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -20,10 +20,9 @@ import inspect
|
|
|
20
20
|
import os
|
|
21
21
|
import time
|
|
22
22
|
from collections import OrderedDict
|
|
23
|
-
from types import FunctionType, MethodType
|
|
24
23
|
import numpy
|
|
25
24
|
|
|
26
|
-
from mindspore._checkparam import args_type_check
|
|
25
|
+
from mindspore._checkparam import args_type_check, check_hook_fn
|
|
27
26
|
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
|
|
28
27
|
from mindspore import log as logger
|
|
29
28
|
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
|
@@ -34,7 +33,7 @@ from mindspore._c_expression import init_pipeline, update_func_graph_hyper_param
|
|
|
34
33
|
from mindspore import _checkparam as Validator
|
|
35
34
|
from mindspore.common import dtype as mstype
|
|
36
35
|
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
|
37
|
-
from mindspore.common.api import _generate_branch_control_input
|
|
36
|
+
from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
|
|
38
37
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
39
38
|
from mindspore.common.tensor import Tensor
|
|
40
39
|
from mindspore.ops.operations import Cast
|
|
@@ -43,15 +42,7 @@ from mindspore.ops.operations import _inner_ops as inner
|
|
|
43
42
|
from mindspore.parallel.shard import Shard
|
|
44
43
|
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
|
45
44
|
from mindspore.common._decorator import deprecated
|
|
46
|
-
from mindspore.
|
|
47
|
-
from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def _check_args(args):
|
|
51
|
-
"""Check the input args's type"""
|
|
52
|
-
for item in args:
|
|
53
|
-
if isinstance(item, Tensor) and item.has_init:
|
|
54
|
-
item.init_data()
|
|
45
|
+
from mindspore.common._register_for_recompute import recompute_registry
|
|
55
46
|
|
|
56
47
|
|
|
57
48
|
class Cell(Cell_):
|
|
@@ -109,7 +100,7 @@ class Cell(Cell_):
|
|
|
109
100
|
"""
|
|
110
101
|
|
|
111
102
|
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
|
|
112
|
-
'_func_graph_flags', '_parameter_layout_dict', '_params_list', '
|
|
103
|
+
'_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
|
|
113
104
|
'_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
|
|
114
105
|
'_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
|
|
115
106
|
'_attr_synced', 'pynative', 'requires_grad', 'cell_type']
|
|
@@ -119,7 +110,6 @@ class Cell(Cell_):
|
|
|
119
110
|
self._params = OrderedDict()
|
|
120
111
|
self._cells = OrderedDict()
|
|
121
112
|
self._params_list = OrderedDict()
|
|
122
|
-
self._tensor_list = OrderedDict()
|
|
123
113
|
self._primitives = OrderedDict()
|
|
124
114
|
self.training = False
|
|
125
115
|
self.requires_grad = False
|
|
@@ -135,11 +125,13 @@ class Cell(Cell_):
|
|
|
135
125
|
self._create_time = int(time.time() * 1e9)
|
|
136
126
|
self.arguments_key = ""
|
|
137
127
|
self.compile_cache = set()
|
|
128
|
+
self.phase_cache = dict()
|
|
138
129
|
cells_compile_cache[id(self)] = self.compile_cache
|
|
139
130
|
self.parameter_broadcast_done = False
|
|
140
131
|
self._id = 1
|
|
141
132
|
self.exist_names = set("")
|
|
142
133
|
self.exist_objs = set()
|
|
134
|
+
self.recompute_cell = None
|
|
143
135
|
init_pipeline()
|
|
144
136
|
|
|
145
137
|
# call gc to release GE session resources used by non-used cell objects
|
|
@@ -161,12 +153,14 @@ class Cell(Cell_):
|
|
|
161
153
|
self._has_config_recompute = False
|
|
162
154
|
self._user_parameters = []
|
|
163
155
|
self._dynamic_shape_inputs = None
|
|
156
|
+
self._compile_args = None
|
|
164
157
|
self.saved_dynamic_shape = None
|
|
165
158
|
self._jit_config_dict = dict()
|
|
166
159
|
self.grad_ops_label = False
|
|
167
160
|
self.ge_sync_data = False
|
|
168
161
|
self._is_check_and_refresh = False
|
|
169
162
|
self._amp_level = ""
|
|
163
|
+
self._init_flag = False
|
|
170
164
|
|
|
171
165
|
def __getstate__(self):
|
|
172
166
|
base = Cell_.__getstate__(self)
|
|
@@ -225,7 +219,7 @@ class Cell(Cell_):
|
|
|
225
219
|
|
|
226
220
|
Tutorial Examples:
|
|
227
221
|
- `Cell and Parameter - Custom Cell Reverse
|
|
228
|
-
<https://mindspore.cn/tutorials/en/
|
|
222
|
+
<https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
|
|
229
223
|
"""
|
|
230
224
|
return self._bprop_debug
|
|
231
225
|
|
|
@@ -317,10 +311,23 @@ class Cell(Cell_):
|
|
|
317
311
|
|
|
318
312
|
@property
|
|
319
313
|
def pipeline_stage(self):
|
|
314
|
+
"""
|
|
315
|
+
`pipeline_stage` represents the pipeline stage of current Cell.
|
|
316
|
+
"""
|
|
320
317
|
return self._pipeline_stage
|
|
321
318
|
|
|
322
319
|
@pipeline_stage.setter
|
|
323
320
|
def pipeline_stage(self, value):
|
|
321
|
+
"""
|
|
322
|
+
Set the `pipeline_stage` of a Cell.
|
|
323
|
+
|
|
324
|
+
Args:
|
|
325
|
+
value (int): The pipeline stage of a parameter.
|
|
326
|
+
|
|
327
|
+
Raises:
|
|
328
|
+
TypeError: If `value` is not int type or is a bool type.
|
|
329
|
+
ValueError: If `value` is not a positive integer.
|
|
330
|
+
"""
|
|
324
331
|
if not isinstance(value, int) or isinstance(value, bool):
|
|
325
332
|
raise TypeError("For 'Cell', the property 'pipeline_stage' "
|
|
326
333
|
"must be int type, but got type : {}".format(type(value)))
|
|
@@ -376,10 +383,6 @@ class Cell(Cell_):
|
|
|
376
383
|
cells = self.__dict__['_cells']
|
|
377
384
|
if name in cells:
|
|
378
385
|
return cells[name]
|
|
379
|
-
if '_tensor_list' in self.__dict__:
|
|
380
|
-
tensor_list = self.__dict__['_tensor_list']
|
|
381
|
-
if name in tensor_list:
|
|
382
|
-
return tensor_list[name]
|
|
383
386
|
if '_params_list' in self.__dict__:
|
|
384
387
|
params_list = self.__dict__['_params_list']
|
|
385
388
|
if name in params_list:
|
|
@@ -391,12 +394,8 @@ class Cell(Cell_):
|
|
|
391
394
|
# while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
|
|
392
395
|
# here using pop(id(self), None) to avoid KeyError exception
|
|
393
396
|
cells_compile_cache.pop(id(self), None)
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
_cell_graph_executor.del_net_res(self, self.compile_cache)
|
|
397
|
-
except AttributeError as e:
|
|
398
|
-
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
|
399
|
-
f"Please use 'super().__init__()'.") from e
|
|
397
|
+
if hasattr(self, "compile_cache") and self.compile_cache:
|
|
398
|
+
_cell_graph_executor.del_net_res(self, self.compile_cache)
|
|
400
399
|
|
|
401
400
|
def __delattr__(self, name):
|
|
402
401
|
if name in self._params:
|
|
@@ -405,8 +404,6 @@ class Cell(Cell_):
|
|
|
405
404
|
del self._cells[name]
|
|
406
405
|
elif '_params_list' in self.__dict__ and name in self._params_list:
|
|
407
406
|
del self._params_list[name]
|
|
408
|
-
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
409
|
-
del self._tensor_list[name]
|
|
410
407
|
else:
|
|
411
408
|
object.__delattr__(self, name)
|
|
412
409
|
self._attr_synced = False
|
|
@@ -420,7 +417,7 @@ class Cell(Cell_):
|
|
|
420
417
|
elif isinstance(item, float):
|
|
421
418
|
res.append(self.cast(item, dst_type))
|
|
422
419
|
elif hasattr(item, "dtype") and item.dtype in \
|
|
423
|
-
|
|
420
|
+
{mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
|
|
424
421
|
res.append(self.cast(item, dst_type))
|
|
425
422
|
else:
|
|
426
423
|
res.append(item)
|
|
@@ -479,7 +476,10 @@ class Cell(Cell_):
|
|
|
479
476
|
elif hasattr(self, "_shard_fn"):
|
|
480
477
|
output = self._shard_fn(*cast_inputs, **kwargs)
|
|
481
478
|
else:
|
|
482
|
-
|
|
479
|
+
if self.recompute_cell is not None:
|
|
480
|
+
output = self.recompute_cell(*cast_inputs, **kwargs)
|
|
481
|
+
else:
|
|
482
|
+
output = self.construct(*cast_inputs, **kwargs)
|
|
483
483
|
if self._enable_forward_hook:
|
|
484
484
|
output = self._run_forward_hook(cast_inputs, output)
|
|
485
485
|
return output
|
|
@@ -654,48 +654,56 @@ class Cell(Cell_):
|
|
|
654
654
|
|
|
655
655
|
return cast_inputs
|
|
656
656
|
|
|
657
|
-
def
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
if kwargs:
|
|
662
|
-
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
|
|
663
|
-
bound_arguments.apply_defaults()
|
|
664
|
-
args = bound_arguments.args
|
|
665
|
-
kwargs = bound_arguments.kwargs
|
|
657
|
+
def _init_check(self):
|
|
658
|
+
for param in self.get_parameters(expand=False):
|
|
659
|
+
if param.has_init:
|
|
660
|
+
param.init_data()
|
|
666
661
|
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
|
|
662
|
+
def _self_check(self):
|
|
663
|
+
if not self._is_check_and_refresh:
|
|
671
664
|
self.check_names_and_refresh_name()
|
|
672
665
|
self._is_check_and_refresh = True
|
|
673
666
|
|
|
667
|
+
def _predict(self, *args, **kwargs):
|
|
668
|
+
if not hasattr(self, "phase"):
|
|
669
|
+
return False, None
|
|
670
|
+
if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
|
|
671
|
+
new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
|
|
672
|
+
res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
|
|
673
|
+
res = _convert_python_data(res)
|
|
674
|
+
return True, res
|
|
675
|
+
return False, None
|
|
676
|
+
|
|
677
|
+
def __call__(self, *args, **kwargs):
|
|
674
678
|
# Run in Graph mode.
|
|
675
679
|
if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
|
|
680
|
+
if kwargs:
|
|
681
|
+
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
|
|
682
|
+
bound_arguments.apply_defaults()
|
|
683
|
+
args = bound_arguments.args
|
|
684
|
+
kwargs = bound_arguments.kwargs
|
|
685
|
+
|
|
686
|
+
predict_compiled, res = self._predict(*args, **kwargs)
|
|
687
|
+
if predict_compiled:
|
|
688
|
+
return res
|
|
676
689
|
self._check_construct_args(*args)
|
|
690
|
+
|
|
677
691
|
if self._hook_fn_registered():
|
|
678
692
|
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
|
|
679
693
|
f"function, please use context.set_context to set pynative mode.")
|
|
694
|
+
self._self_check()
|
|
680
695
|
out = self.compile_and_run(*args, **kwargs)
|
|
681
696
|
return out
|
|
682
697
|
|
|
683
698
|
# Run in PyNative mode.
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
self._do_parameter_broadcast()
|
|
699
|
+
self._self_check()
|
|
700
|
+
if not self._init_flag:
|
|
701
|
+
self._init_check()
|
|
702
|
+
self._init_flag = True
|
|
689
703
|
|
|
690
|
-
|
|
691
|
-
self._check_cell_flags_in_pynative()
|
|
692
|
-
|
|
693
|
-
if self.requires_grad and _pynative_executor.enable_grad():
|
|
704
|
+
if self.requires_grad:
|
|
694
705
|
_pynative_executor.set_grad_flag(True)
|
|
695
706
|
|
|
696
|
-
if self._dynamic_shape_inputs is not None:
|
|
697
|
-
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
698
|
-
|
|
699
707
|
try:
|
|
700
708
|
_pynative_executor.new_graph(self, *args, **kwargs)
|
|
701
709
|
output = self._run_construct(args, kwargs)
|
|
@@ -704,16 +712,8 @@ class Cell(Cell_):
|
|
|
704
712
|
_pynative_executor.clear_res()
|
|
705
713
|
raise err
|
|
706
714
|
|
|
707
|
-
if isinstance(output, Parameter):
|
|
708
|
-
output = output.data
|
|
709
715
|
return output
|
|
710
716
|
|
|
711
|
-
def _check_cell_flags_in_pynative(self):
|
|
712
|
-
"""Check the flags added to cell in pynative mode"""
|
|
713
|
-
if hasattr(self, "_func_graph_flags") and self._func_graph_flags.get("output_no_recompute"):
|
|
714
|
-
raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
|
|
715
|
-
"'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
|
|
716
|
-
|
|
717
717
|
def _add_attr(self, name, value):
|
|
718
718
|
if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
|
|
719
719
|
super(Cell, self)._add_attr(name, value)
|
|
@@ -830,15 +830,6 @@ class Cell(Cell_):
|
|
|
830
830
|
else:
|
|
831
831
|
self.insert_param_to_cell(name, None)
|
|
832
832
|
|
|
833
|
-
def _set_attr_for_tensor(self, name, value):
|
|
834
|
-
if context._get_mode() == context.PYNATIVE_MODE:
|
|
835
|
-
tensor_list = self.__dict__.get('_tensor_list')
|
|
836
|
-
if name in self.__dict__:
|
|
837
|
-
del self.__dict__[name]
|
|
838
|
-
tensor_list[name] = value
|
|
839
|
-
else:
|
|
840
|
-
object.__setattr__(self, name, value)
|
|
841
|
-
|
|
842
833
|
def __setattr__(self, name, value):
|
|
843
834
|
cells = self.__dict__.get('_cells')
|
|
844
835
|
params = self.__dict__.get('_params')
|
|
@@ -856,8 +847,6 @@ class Cell(Cell_):
|
|
|
856
847
|
if value is not None:
|
|
857
848
|
raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
|
|
858
849
|
self._cells[name] = None
|
|
859
|
-
elif isinstance(value, Tensor):
|
|
860
|
-
self._set_attr_for_tensor(name, value)
|
|
861
850
|
else:
|
|
862
851
|
if isinstance(value, Primitive):
|
|
863
852
|
value.set_prim_instance_name(name)
|
|
@@ -981,6 +970,20 @@ class Cell(Cell_):
|
|
|
981
970
|
|
|
982
971
|
return self._dynamic_shape_inputs
|
|
983
972
|
|
|
973
|
+
def _get_compile_args(self, args):
|
|
974
|
+
"""Get compile arguments."""
|
|
975
|
+
# this is used only for test
|
|
976
|
+
if is_auto_dynamic() and (self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None):
|
|
977
|
+
self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
|
|
978
|
+
|
|
979
|
+
if self._dynamic_shape_inputs is not None:
|
|
980
|
+
logger.debug("Compiled Graph with dynamic shape")
|
|
981
|
+
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
982
|
+
Validator.check_symbolic_shape(self._dynamic_shape_inputs, args)
|
|
983
|
+
self.saved_dynamic_shape = self._dynamic_shape_inputs
|
|
984
|
+
return self._dynamic_shape_inputs
|
|
985
|
+
return args
|
|
986
|
+
|
|
984
987
|
def compile(self, *args, **kwargs):
|
|
985
988
|
"""
|
|
986
989
|
Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
|
|
@@ -989,19 +992,9 @@ class Cell(Cell_):
|
|
|
989
992
|
args (tuple): Args of the Cell object.
|
|
990
993
|
kwargs (dict): Kwargs of the Cell object.
|
|
991
994
|
"""
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
if self._dynamic_shape_inputs is None:
|
|
997
|
-
_cell_graph_executor.compile(self, phase=self.phase,
|
|
998
|
-
jit_config_dict=self._jit_config_dict, *args, **kwargs)
|
|
999
|
-
else:
|
|
1000
|
-
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
1001
|
-
self.saved_dynamic_shape = self._dynamic_shape_inputs
|
|
1002
|
-
_cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
|
|
1003
|
-
jit_config_dict=self._jit_config_dict, **kwargs)
|
|
1004
|
-
logger.debug("Compiled Graph with dynamic shape")
|
|
995
|
+
self._compile_args = self._get_compile_args(args)
|
|
996
|
+
_cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
|
|
997
|
+
jit_config_dict=self._jit_config_dict, **kwargs)
|
|
1005
998
|
|
|
1006
999
|
def compile_and_run(self, *args, **kwargs):
|
|
1007
1000
|
"""
|
|
@@ -1019,7 +1012,7 @@ class Cell(Cell_):
|
|
|
1019
1012
|
"""
|
|
1020
1013
|
self.compile(*args, **kwargs)
|
|
1021
1014
|
self.add_flags(ge_sync_data=False)
|
|
1022
|
-
new_args = _get_args_for_run(self, args, kwargs)
|
|
1015
|
+
new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
|
|
1023
1016
|
return _cell_graph_executor(self, *new_args, phase=self.phase)
|
|
1024
1017
|
|
|
1025
1018
|
def auto_parallel_compile_and_run(self):
|
|
@@ -1070,14 +1063,14 @@ class Cell(Cell_):
|
|
|
1070
1063
|
Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
|
|
1071
1064
|
"""
|
|
1072
1065
|
if not param_name:
|
|
1073
|
-
raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
|
|
1066
|
+
raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.")
|
|
1074
1067
|
if check_name_contain_dot and '.' in param_name:
|
|
1075
|
-
raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not contain
|
|
1068
|
+
raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain'.' ")
|
|
1076
1069
|
if '_params' not in self.__dict__:
|
|
1077
|
-
raise AttributeError("For 'insert_param_to_cell', please call Cell.__init__() firstly.")
|
|
1070
|
+
raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() firstly.")
|
|
1078
1071
|
if hasattr(self, param_name) and param_name not in self._params:
|
|
1079
|
-
raise KeyError("For 'insert_param_to_cell', the {} parameter already exists in the network.
|
|
1080
|
-
"insert another parameter with the same name."
|
|
1072
|
+
raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network."
|
|
1073
|
+
f"Cannot insert another parameter with the same name.")
|
|
1081
1074
|
if not isinstance(param, Parameter) and param is not None:
|
|
1082
1075
|
raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
|
|
1083
1076
|
f"but got {type(param)}.")
|
|
@@ -1139,11 +1132,11 @@ class Cell(Cell_):
|
|
|
1139
1132
|
raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
|
|
1140
1133
|
f"but got {type(child_name)}.")
|
|
1141
1134
|
if not child_name or '.' in child_name:
|
|
1142
|
-
raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
|
|
1143
|
-
"can not contain '.'")
|
|
1135
|
+
raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
|
|
1136
|
+
"can not contain '.' ")
|
|
1144
1137
|
if hasattr(self, child_name) and child_name not in self._cells:
|
|
1145
|
-
raise KeyError("For 'insert_child_to_cell', the {} child cell already exists in the network.
|
|
1146
|
-
"insert another child cell with the same name."
|
|
1138
|
+
raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network."
|
|
1139
|
+
f"Cannot insert another child cell with the same name.")
|
|
1147
1140
|
if not isinstance(child_cell, Cell) and child_cell is not None:
|
|
1148
1141
|
raise TypeError(f"For 'insert_child_to_cell', the argument 'child_cell' must be 'Cell' if not None, "
|
|
1149
1142
|
f"but got type {type(child_cell)}.")
|
|
@@ -1163,7 +1156,7 @@ class Cell(Cell_):
|
|
|
1163
1156
|
Returns:
|
|
1164
1157
|
Tensor, returns the computed result.
|
|
1165
1158
|
"""
|
|
1166
|
-
|
|
1159
|
+
raise AttributeError("For 'Cell', the method 'construct' is not defined.")
|
|
1167
1160
|
|
|
1168
1161
|
def remove_redundant_parameters(self):
|
|
1169
1162
|
"""
|
|
@@ -1361,7 +1354,7 @@ class Cell(Cell_):
|
|
|
1361
1354
|
|
|
1362
1355
|
Tutorial Examples:
|
|
1363
1356
|
- `Model Training - Optimizer
|
|
1364
|
-
<https://mindspore.cn/tutorials/en/
|
|
1357
|
+
<https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
|
|
1365
1358
|
"""
|
|
1366
1359
|
return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
|
|
1367
1360
|
|
|
@@ -1472,7 +1465,7 @@ class Cell(Cell_):
|
|
|
1472
1465
|
|
|
1473
1466
|
Tutorial Examples:
|
|
1474
1467
|
- `Building a Network - Model Parameters
|
|
1475
|
-
<https://mindspore.cn/tutorials/en/
|
|
1468
|
+
<https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
|
|
1476
1469
|
"""
|
|
1477
1470
|
cells = []
|
|
1478
1471
|
if expand:
|
|
@@ -1698,6 +1691,9 @@ class Cell(Cell_):
|
|
|
1698
1691
|
if not hasattr(self, "_func_graph_flags"):
|
|
1699
1692
|
self._func_graph_flags = {}
|
|
1700
1693
|
self._func_graph_flags.update({**flags})
|
|
1694
|
+
if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
|
|
1695
|
+
raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
|
|
1696
|
+
"'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
|
|
1701
1697
|
self.__dict__.update({**flags})
|
|
1702
1698
|
self._add_mixed_precision_flag(**flags)
|
|
1703
1699
|
return self
|
|
@@ -1808,7 +1804,7 @@ class Cell(Cell_):
|
|
|
1808
1804
|
accelerate the algorithm in the algorithm library.
|
|
1809
1805
|
|
|
1810
1806
|
If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
|
|
1811
|
-
`algorithm library <https://gitee.com/mindspore/mindspore/tree/
|
|
1807
|
+
`algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.
|
|
1812
1808
|
|
|
1813
1809
|
Note:
|
|
1814
1810
|
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
|
|
@@ -1865,7 +1861,7 @@ class Cell(Cell_):
|
|
|
1865
1861
|
|
|
1866
1862
|
Tutorial Examples:
|
|
1867
1863
|
- `Model Training - Implementing Training and Evaluation
|
|
1868
|
-
<https://mindspore.cn/tutorials/en/
|
|
1864
|
+
<https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
|
|
1869
1865
|
"""
|
|
1870
1866
|
if mode:
|
|
1871
1867
|
self._phase = 'train'
|
|
@@ -1959,8 +1955,8 @@ class Cell(Cell_):
|
|
|
1959
1955
|
hook_fn (function): Python function. Forward pre hook function.
|
|
1960
1956
|
|
|
1961
1957
|
Returns:
|
|
1962
|
-
|
|
1963
|
-
|
|
1958
|
+
A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
|
|
1959
|
+
`handle.remove()` .
|
|
1964
1960
|
|
|
1965
1961
|
Raises:
|
|
1966
1962
|
TypeError: If the `hook_fn` is not a function of python.
|
|
@@ -1995,17 +1991,8 @@ class Cell(Cell_):
|
|
|
1995
1991
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
|
|
1996
1992
|
value= [ 2.00000000e+00]))
|
|
1997
1993
|
"""
|
|
1998
|
-
if
|
|
1999
|
-
logger.warning(f"'register_forward_pre_hook' function is only supported in pynative mode, you can use "
|
|
2000
|
-
f"context.set_context to set pynative mode.")
|
|
1994
|
+
if not check_hook_fn("register_forward_pre_hook", hook_fn):
|
|
2001
1995
|
return HookHandle()
|
|
2002
|
-
|
|
2003
|
-
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
|
2004
|
-
raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
|
|
2005
|
-
f"function, but got {type(hook_fn)}.")
|
|
2006
|
-
if hook_fn.__code__.co_name == "staging_specialize":
|
|
2007
|
-
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
|
|
2008
|
-
|
|
2009
1996
|
self._enable_forward_pre_hook = True
|
|
2010
1997
|
_pynative_executor.set_hook_changed(self)
|
|
2011
1998
|
if not hasattr(self, '_forward_pre_hook_key'):
|
|
@@ -2059,8 +2046,8 @@ class Cell(Cell_):
|
|
|
2059
2046
|
hook_fn (function): Python function. Forward hook function.
|
|
2060
2047
|
|
|
2061
2048
|
Returns:
|
|
2062
|
-
|
|
2063
|
-
|
|
2049
|
+
A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
|
|
2050
|
+
`handle.remove()` .
|
|
2064
2051
|
|
|
2065
2052
|
Raises:
|
|
2066
2053
|
TypeError: If the `hook_fn` is not a function of python.
|
|
@@ -2097,17 +2084,8 @@ class Cell(Cell_):
|
|
|
2097
2084
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
|
|
2098
2085
|
value= [ 2.00000000e+00]))
|
|
2099
2086
|
"""
|
|
2100
|
-
if
|
|
2101
|
-
logger.warning(f"'register_forward_hook' function is only supported in pynative mode, you can use "
|
|
2102
|
-
f"context.set_context to set pynative mode.")
|
|
2087
|
+
if not check_hook_fn("register_forward_hook", hook_fn):
|
|
2103
2088
|
return HookHandle()
|
|
2104
|
-
|
|
2105
|
-
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
|
2106
|
-
raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
|
|
2107
|
-
f"function, but got {type(hook_fn)}.")
|
|
2108
|
-
if hook_fn.__code__.co_name == "staging_specialize":
|
|
2109
|
-
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
|
|
2110
|
-
|
|
2111
2089
|
self._enable_forward_hook = True
|
|
2112
2090
|
_pynative_executor.set_hook_changed(self)
|
|
2113
2091
|
if not hasattr(self, '_forward_hook_key'):
|
|
@@ -2159,8 +2137,8 @@ class Cell(Cell_):
|
|
|
2159
2137
|
hook_fn (function): Python function. Backward hook function.
|
|
2160
2138
|
|
|
2161
2139
|
Returns:
|
|
2162
|
-
|
|
2163
|
-
|
|
2140
|
+
A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
|
|
2141
|
+
`handle.remove()` .
|
|
2164
2142
|
|
|
2165
2143
|
Raises:
|
|
2166
2144
|
TypeError: If the `hook_fn` is not a function of python.
|
|
@@ -2195,14 +2173,8 @@ class Cell(Cell_):
|
|
|
2195
2173
|
>>> print(output)
|
|
2196
2174
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
|
|
2197
2175
|
"""
|
|
2198
|
-
if
|
|
2199
|
-
logger.warning(f"'register_backward_hook' function is only supported in pynative mode, you can use "
|
|
2200
|
-
f"context.set_context to set pynative mode.")
|
|
2176
|
+
if not check_hook_fn("register_backward_hook", hook_fn):
|
|
2201
2177
|
return HookHandle()
|
|
2202
|
-
|
|
2203
|
-
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
|
2204
|
-
raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
|
|
2205
|
-
f"function, but got {type(hook_fn)}.")
|
|
2206
2178
|
if self._cell_backward_hook is None:
|
|
2207
2179
|
self._enable_backward_hook = True
|
|
2208
2180
|
self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
|
|
@@ -2232,10 +2204,16 @@ class Cell(Cell_):
|
|
|
2232
2204
|
else:
|
|
2233
2205
|
inputs = self._cell_backward_hook(*inputs)
|
|
2234
2206
|
inputs = (inputs,)
|
|
2235
|
-
if
|
|
2236
|
-
|
|
2207
|
+
if self.recompute_cell is not None:
|
|
2208
|
+
if isinstance(inputs, tuple):
|
|
2209
|
+
outputs = self.recompute_cell(*inputs, **kwargs)
|
|
2210
|
+
else:
|
|
2211
|
+
outputs = self.recompute_cell(inputs, **kwargs)
|
|
2237
2212
|
else:
|
|
2238
|
-
|
|
2213
|
+
if isinstance(inputs, tuple):
|
|
2214
|
+
outputs = self.construct(*inputs, **kwargs)
|
|
2215
|
+
else:
|
|
2216
|
+
outputs = self.construct(inputs, **kwargs)
|
|
2239
2217
|
outputs = self._cell_backward_hook(outputs)
|
|
2240
2218
|
return outputs
|
|
2241
2219
|
|
|
@@ -2365,6 +2343,9 @@ class Cell(Cell_):
|
|
|
2365
2343
|
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
|
|
2366
2344
|
Default: ``False`` .
|
|
2367
2345
|
"""
|
|
2346
|
+
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
2347
|
+
self.recompute_cell = recompute_registry.get()(self.construct)
|
|
2348
|
+
return
|
|
2368
2349
|
self._recompute()
|
|
2369
2350
|
if 'mp_comm_recompute' in kwargs.keys():
|
|
2370
2351
|
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
|
@@ -2384,16 +2365,13 @@ class Cell(Cell_):
|
|
|
2384
2365
|
"the key kwargs must be 'mp_comm_recompute', "
|
|
2385
2366
|
"'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
|
|
2386
2367
|
|
|
2368
|
+
@deprecated("2.3", "infer_param_pipeline_stage")
|
|
2387
2369
|
def infer_param_pipeline_stage(self):
|
|
2388
2370
|
"""
|
|
2389
2371
|
Infer pipeline stages of all parameters in the cell.
|
|
2390
2372
|
|
|
2391
2373
|
Note:
|
|
2392
|
-
-
|
|
2393
|
-
the parameter should use add_pipeline_stage to add it's pipeline_stage information.
|
|
2394
|
-
- If a parameter P has been used by two operators in different stages "stageA" and "stageB",
|
|
2395
|
-
the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
|
|
2396
|
-
to add it's stage information before using infer_param_pipeline_stage.
|
|
2374
|
+
- The interface is deprecated from version 2.3 and will be removed in a future version.
|
|
2397
2375
|
|
|
2398
2376
|
Returns:
|
|
2399
2377
|
The params belong to current stage in pipeline parallel.
|
|
@@ -2490,43 +2468,22 @@ class Cell(Cell_):
|
|
|
2490
2468
|
Args:
|
|
2491
2469
|
net_inputs (tuple): Inputs of the Cell object.
|
|
2492
2470
|
"""
|
|
2493
|
-
|
|
2494
|
-
|
|
2495
|
-
|
|
2496
|
-
|
|
2497
|
-
|
|
2471
|
+
if not getattr(set_inputs, '__ms_dynamic_len__', False):
|
|
2472
|
+
set_inputs_len = len(set_inputs)
|
|
2473
|
+
net_inputs_len = len(net_inputs)
|
|
2474
|
+
if set_inputs_len != net_inputs_len:
|
|
2475
|
+
raise ValueError(f"The length of 'set_inputs' or tuple(list) in 'set_inputs' "
|
|
2476
|
+
f"must be equal to network's inputs, but got 'set_inputs': "
|
|
2477
|
+
f"{set_inputs_len} and network's input: {net_inputs_len}.")
|
|
2498
2478
|
for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
|
|
2499
2479
|
if isinstance(set_input, Tensor):
|
|
2500
2480
|
self._check_dynamic_tensor(set_input, net_input, index)
|
|
2501
2481
|
elif isinstance(set_input, (tuple, list)):
|
|
2502
2482
|
if not isinstance(net_input, (tuple, list)):
|
|
2503
2483
|
raise TypeError(
|
|
2504
|
-
f"The {index + 1}th input type of 'set_inputs' or tuple(list) in
|
|
2505
|
-
f"list, but got {type(net_input)}.")
|
|
2484
|
+
f"The {index + 1}th input type of 'set_inputs' or tuple(list) in "
|
|
2485
|
+
f"'set_inputs' must be tuple or list, but got {type(net_input)}.")
|
|
2506
2486
|
self._check_compile_dynamic_shape(set_input, net_input)
|
|
2507
|
-
else:
|
|
2508
|
-
if context._get_mode() == context.PYNATIVE_MODE and set_input is None:
|
|
2509
|
-
continue
|
|
2510
|
-
if net_input != set_input:
|
|
2511
|
-
raise ValueError(
|
|
2512
|
-
f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
|
|
2513
|
-
f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
|
|
2514
|
-
|
|
2515
|
-
def _run_tracefunc(self, *args, **kwargs):
|
|
2516
|
-
""" Run Packed Cell in Pack."""
|
|
2517
|
-
args = self._mixed_precision_cast(args)
|
|
2518
|
-
need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
|
|
2519
|
-
if not PackFunc.current.is_pynative_mode and need_subgraph:
|
|
2520
|
-
expander = PackExpander.get_instance()
|
|
2521
|
-
args = expander.begin_subgraph(self, *args)
|
|
2522
|
-
args = [_convert_tensor(a) for a in args]
|
|
2523
|
-
output = self._run_construct(args, kwargs)
|
|
2524
|
-
ret = expander.end_subgraph(self, output)
|
|
2525
|
-
output = _convert_tensor(ret)
|
|
2526
|
-
else:
|
|
2527
|
-
with _SetMixedPrecision(self):
|
|
2528
|
-
output = self._run_construct(args, kwargs)
|
|
2529
|
-
return output
|
|
2530
2487
|
|
|
2531
2488
|
def _mixed_precision_cast(self, inputs):
|
|
2532
2489
|
mixed_type = self.get_mixed_precision_type()
|
|
@@ -2633,7 +2590,7 @@ class GraphCell(Cell):
|
|
|
2633
2590
|
self._add_attr("graph_load_from_mindir", self.graph)
|
|
2634
2591
|
if not self.obf_random_seed:
|
|
2635
2592
|
return self.compile_and_run(*args, **kwargs)
|
|
2636
|
-
append_input = Tensor((numpy.ones((1,
|
|
2593
|
+
append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
|
|
2637
2594
|
return self.compile_and_run(*args, append_input, **kwargs)
|
|
2638
2595
|
|
|
2639
2596
|
|