mindspore 2.2.14__cp39-cp39-manylinux1_x86_64.whl → 2.3.0rc1__cp39-cp39-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -4
- mindspore/_akg/akg/composite/build_module.py +155 -11
- mindspore/_akg/akg/config/repository.json +38 -0
- mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
- mindspore/_akg/akg/utils/gen_random.py +14 -8
- mindspore/_akg/akg/utils/op_dsl.py +11 -0
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +5 -5
- mindspore/_c_dataengine.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +58 -0
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +229 -0
- mindspore/_extends/parse/parser.py +155 -59
- mindspore/_extends/parse/resources.py +40 -7
- mindspore/_extends/parse/standard_method.py +124 -204
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_mindspore_offline_debug.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +24 -18
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +3 -1
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_stub_tensor.py +6 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +91 -48
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +5 -4
- mindspore/common/dump.py +5 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +20 -11
- mindspore/common/lazy_inline.py +58 -17
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/parameter.py +19 -4
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +251 -18
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +321 -433
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/management.py +53 -38
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +167 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +72 -38
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +7 -7
- mindspore/dataset/engine/datasets_vision.py +75 -71
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/transforms.py +2264 -514
- mindspore/dataset/vision/utils.py +40 -9
- mindspore/dataset/vision/validators.py +7 -1
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +60 -119
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +34 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/stream.py +337 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +1 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/mindapi/base/format.h +125 -23
- mindspore/include/mindapi/base/types.h +7 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6325 -1767
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +74 -56
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/multiprocessing/__init__.py +68 -0
- mindspore/nn/cell.py +86 -133
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +79 -90
- mindspore/nn/layer/basic.py +4 -80
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +71 -71
- mindspore/nn/layer/embedding.py +105 -44
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/normalization.py +46 -38
- mindspore/nn/layer/padding.py +26 -39
- mindspore/nn/layer/pooling.py +13 -9
- mindspore/nn/layer/rnn_cells.py +5 -15
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +43 -64
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +6 -5
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +33 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +7 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +54 -60
- mindspore/numpy/utils.py +3 -0
- mindspore/ops/__init__.py +5 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -22
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +137 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +102 -56
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +74 -49
- mindspore/ops/_vmap/vmap_nn_ops.py +164 -89
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +133 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +248 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +147 -0
- mindspore/ops/auto_generate/gen_extend_func.py +130 -0
- mindspore/ops/auto_generate/gen_ops_def.py +4786 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +8335 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +77 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +118 -17
- mindspore/ops/composite/math_ops.py +9 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +166 -601
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +15 -133
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +46 -0
- mindspore/ops/extend/array_func.py +152 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/{_op_impl/tbe/atomic_addr_clean.py → extend/nn_func.py} +5 -15
- mindspore/ops/function/__init__.py +19 -11
- mindspore/ops/function/array_func.py +251 -1440
- mindspore/ops/function/clip_func.py +12 -13
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +24 -17
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +451 -2360
- mindspore/ops/function/nn_func.py +459 -780
- mindspore/ops/function/other_func.py +4 -5
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +24 -80
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +56 -62
- mindspore/ops/op_info_register.py +22 -19
- mindspore/ops/operations/__init__.py +19 -19
- mindspore/ops/operations/_grad_ops.py +20 -723
- mindspore/ops/operations/_inner_ops.py +178 -286
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +4 -34
- mindspore/ops/operations/array_ops.py +99 -2491
- mindspore/ops/operations/comm_ops.py +38 -46
- mindspore/ops/operations/custom_ops.py +8 -8
- mindspore/ops/operations/debug_ops.py +100 -31
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +3 -38
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +1391 -0
- mindspore/ops/operations/math_ops.py +703 -4601
- mindspore/ops/operations/nn_ops.py +374 -1748
- mindspore/ops/operations/other_ops.py +50 -42
- mindspore/ops/operations/random_ops.py +3 -52
- mindspore/ops/primitive.py +196 -96
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +248 -0
- mindspore/ops_generate/arg_handler.py +147 -0
- mindspore/ops_generate/gen_aclnn_implement.py +266 -0
- mindspore/ops_generate/gen_ops.py +1062 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +129 -0
- mindspore/ops_generate/gen_pyboost_func.py +932 -0
- mindspore/ops_generate/gen_utils.py +188 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +364 -0
- mindspore/ops_generate/template.py +238 -0
- mindspore/parallel/__init__.py +5 -4
- mindspore/parallel/_auto_parallel_context.py +21 -76
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +30 -46
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +19 -7
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +131 -6
- mindspore/parallel/algo_parameter_config.py +6 -6
- mindspore/parallel/checkpoint_transform.py +180 -196
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +345 -0
- mindspore/parallel/cluster/process_entity/_utils.py +116 -0
- mindspore/parallel/cluster/run.py +139 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +99 -2
- mindspore/profiler/common/util.py +20 -0
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +108 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
- mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
- mindspore/profiler/parser/ascend_msprof_generator.py +68 -285
- mindspore/profiler/parser/ascend_op_generator.py +75 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +293 -135
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +5 -0
- mindspore/profiler/profiling.py +296 -166
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +121 -35
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +635 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +107 -144
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/scipy/__init__.py +2 -1
- mindspore/scipy/fft.py +133 -0
- mindspore/scipy/linalg.py +140 -55
- mindspore/scipy/ops.py +15 -71
- mindspore/scipy/ops_grad.py +5 -34
- mindspore/scipy/optimize/line_search.py +2 -2
- mindspore/scipy/optimize/minimize.py +1 -1
- mindspore/train/__init__.py +3 -2
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/callback/_backup_and_restore.py +4 -4
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +39 -13
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +14 -8
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/data_sink.py +1 -1
- mindspore/train/dataset_helper.py +13 -4
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +75 -6
- mindspore/train/model.py +24 -22
- mindspore/train/serialization.py +256 -132
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/METADATA +2 -2
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/RECORD +515 -1061
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/config/super_bar_config.json +0 -544
- mindspore/gen_ops.py +0 -273
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_creator_register.py +0 -37
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,28 +13,28 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
|
-
from typing import Union
|
|
16
|
+
from typing import Union, List, Dict
|
|
17
|
+
import types
|
|
17
18
|
import os
|
|
18
19
|
import ast
|
|
19
20
|
import sys
|
|
20
21
|
import inspect
|
|
22
|
+
import builtins
|
|
23
|
+
from textwrap import dedent
|
|
21
24
|
|
|
22
25
|
from mindspore import log as logger
|
|
23
|
-
from mindspore.nn import Cell, SequentialCell
|
|
24
|
-
from mindspore.ops import Primitive
|
|
25
|
-
|
|
26
|
-
from
|
|
27
|
-
from
|
|
28
|
-
from
|
|
29
|
-
from
|
|
30
|
-
from
|
|
31
|
-
from
|
|
32
|
-
from
|
|
33
|
-
from
|
|
34
|
-
|
|
35
|
-
from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
|
|
36
|
-
from mindspore.rewrite.ast_helpers import AstReplacer
|
|
37
|
-
from ..common import error_str
|
|
26
|
+
from mindspore.nn import Cell, SequentialCell, CellList
|
|
27
|
+
from mindspore.ops.primitive import Primitive
|
|
28
|
+
import mindspore.ops.functional as F
|
|
29
|
+
from . import Parser, ParserRegister, reg_parser
|
|
30
|
+
from ..symbol_tree import SymbolTree
|
|
31
|
+
from ..node import Node, TreeNode, NodeManager, CallFunction, CellContainer, ControlFlow, LocalPrim
|
|
32
|
+
from ..api.scoped_value import ScopedValue
|
|
33
|
+
from ..ast_helpers import AstFlattener, AstConverter, AstFinder
|
|
34
|
+
from ..common.error_log import error_str
|
|
35
|
+
from ..common.namespace import is_subtree, is_ms_function, is_third_party
|
|
36
|
+
from ..common.namer import FunctionNamer
|
|
37
|
+
|
|
38
38
|
|
|
39
39
|
if sys.version_info >= (3, 9):
|
|
40
40
|
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
@@ -47,75 +47,27 @@ class AssignParser(Parser):
|
|
|
47
47
|
|
|
48
48
|
# Types for creating Cell Container node
|
|
49
49
|
types_for_cell_container = [SequentialCell,]
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
for tuple_elt in tuple_elts:
|
|
72
|
-
if not isinstance(tuple_elt, (ast.Constant, ast.Name, ast.Attribute)):
|
|
73
|
-
raise RuntimeError(error_str(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
|
|
74
|
-
f"but got ast type {type(tuple_elt).__name__}",
|
|
75
|
-
child_node=tuple_elt, father_node=node))
|
|
76
|
-
if isinstance(tuple_elt, ast.Constant):
|
|
77
|
-
tuple_values.append(tuple_elt.value)
|
|
78
|
-
elif isinstance(tuple_elt, ast.Name):
|
|
79
|
-
tuple_values.append(tuple_elt.id)
|
|
80
|
-
elif isinstance(tuple_elt, ast.Attribute):
|
|
81
|
-
tuple_values.append("".join([tuple_elt.value.id, '.', tuple_elt.attr]))
|
|
82
|
-
return ScopedValue.create_variable_value(tuple(tuple_values))
|
|
83
|
-
|
|
84
|
-
@staticmethod
|
|
85
|
-
def _create_scopedvalue(node: ast.expr) -> ScopedValue:
|
|
86
|
-
"""
|
|
87
|
-
Create ScopedValue from an ast node.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
node (ast.expr): An ast node.
|
|
91
|
-
|
|
92
|
-
Returns:
|
|
93
|
-
An instance of ScopedValue.
|
|
94
|
-
|
|
95
|
-
Raises:
|
|
96
|
-
RuntimeError: Value of target of ast.Assign should be an ast.Name when target is an ast.Attribute.
|
|
97
|
-
RuntimeError: Type of input node is unsupported.
|
|
98
|
-
"""
|
|
99
|
-
if isinstance(node, ast.Name):
|
|
100
|
-
return ScopedValue.create_naming_value(node.id)
|
|
101
|
-
if isinstance(node, ast.Attribute):
|
|
102
|
-
scope = node.value
|
|
103
|
-
if not isinstance(scope, ast.Name):
|
|
104
|
-
raise RuntimeError(error_str(f"value of target of ast.Assign should be a ast.Name when target is a "
|
|
105
|
-
f"ast.Attribute, but got ast type '{type(scope).__name__}'",
|
|
106
|
-
child_node=scope, father_node=node))
|
|
107
|
-
return ScopedValue.create_naming_value(node.attr, scope.id)
|
|
108
|
-
if isinstance(node, ast.Tuple):
|
|
109
|
-
return AssignParser._create_scopedvalue_from_tuple_ast(node)
|
|
110
|
-
if isinstance(node, (ast.Constant, ast.NameConstant)):
|
|
111
|
-
return ScopedValue.create_variable_value(node.value)
|
|
112
|
-
if isinstance(node, ast.Num):
|
|
113
|
-
return ScopedValue.create_variable_value(node.n)
|
|
114
|
-
if isinstance(node, (ast.Str, ast.Bytes)):
|
|
115
|
-
return ScopedValue.create_variable_value(node.s)
|
|
116
|
-
raise RuntimeError(error_str(f"only support (ast.Name, ast.Attribute, ast.Tuple, ast.Constant, ast.Num"
|
|
117
|
-
f"ast.Str, ast.Bytes to argument), but got ast type '{type(node).__name__}'",
|
|
118
|
-
father_node=node))
|
|
50
|
+
# If mindspore built-in function to be parsered or skipped
|
|
51
|
+
_skip_ms_function = False
|
|
52
|
+
# Functions in black list will not be parsed
|
|
53
|
+
_function_parse_black_list = [F.arange]
|
|
54
|
+
# Share one implementation for the same instances
|
|
55
|
+
_share_one_implementation = False
|
|
56
|
+
# Implementation caches of sub SymbolTrees, CallFunction nodes and CellContainer nodes
|
|
57
|
+
# Keys are ids of the instance object
|
|
58
|
+
_cached_trees: Dict[int, SymbolTree] = {}
|
|
59
|
+
_cached_functions: Dict[int, Node] = {}
|
|
60
|
+
_cached_cell_containers: Dict[int, Node] = {}
|
|
61
|
+
|
|
62
|
+
def __init__(self):
|
|
63
|
+
super().__init__()
|
|
64
|
+
self._variables_cache = []
|
|
65
|
+
self.stree: SymbolTree = None
|
|
66
|
+
self.ast_assign: ast.Assign = None
|
|
67
|
+
self.node_manager: NodeManager = None
|
|
68
|
+
self.targets: List[ScopedValue] = None
|
|
69
|
+
self.args: List[ScopedValue] = None
|
|
70
|
+
self.kwargs: Dict[str, ScopedValue] = None
|
|
119
71
|
|
|
120
72
|
@staticmethod
|
|
121
73
|
def _get_func_name(ast_call: ast.Call) -> str:
|
|
@@ -127,72 +79,45 @@ class AssignParser(Parser):
|
|
|
127
79
|
|
|
128
80
|
Returns:
|
|
129
81
|
Func name.
|
|
130
|
-
|
|
131
|
-
Raises:
|
|
132
|
-
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
|
|
133
82
|
"""
|
|
134
83
|
func = ast_call.func
|
|
135
84
|
if isinstance(func, ast.Name):
|
|
136
85
|
return func.id
|
|
137
86
|
if isinstance(func, ast.Attribute):
|
|
138
87
|
return func.attr
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
88
|
+
func_full_name = astunparse.unparse(func).strip()
|
|
89
|
+
if func_full_name.count('.') > 0:
|
|
90
|
+
return func_full_name.split('.')[-1]
|
|
91
|
+
return func_full_name
|
|
143
92
|
|
|
144
93
|
@staticmethod
|
|
145
|
-
def _get_func_scope(ast_call: ast.Call
|
|
94
|
+
def _get_func_scope(ast_call: ast.Call) -> str:
|
|
146
95
|
"""
|
|
147
96
|
Get the func scope from ast.Call.
|
|
148
97
|
|
|
149
98
|
Args:
|
|
150
99
|
ast_call (ast.Call): Input ast.Call node.
|
|
151
|
-
node_manager (NodeManager): NodeManager those asts belong to.
|
|
152
100
|
|
|
153
101
|
Returns:
|
|
154
102
|
Func scope.
|
|
155
|
-
|
|
156
|
-
Raises:
|
|
157
|
-
RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute.
|
|
158
|
-
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
|
|
159
103
|
"""
|
|
160
104
|
func = ast_call.func
|
|
161
105
|
if isinstance(func, ast.Name):
|
|
162
106
|
return ""
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
if isinstance(func, ast.Call):
|
|
168
|
-
return AssignParser._get_func_scope(func, node_manager)
|
|
169
|
-
raise RuntimeError(error_str(f"funcValue should be Name or a Attribute or a Call, but got ast type "
|
|
170
|
-
f"'{type(func).__name__}'", child_node=func, father_node=ast_call))
|
|
107
|
+
func_full_name = astunparse.unparse(func).strip()
|
|
108
|
+
if func_full_name.count('.') > 0:
|
|
109
|
+
return func_full_name.rsplit('.', 1)[0]
|
|
110
|
+
return ""
|
|
171
111
|
|
|
172
112
|
@staticmethod
|
|
173
|
-
def
|
|
174
|
-
"""
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
symbol_name (str): Func name.
|
|
179
|
-
origin_net ([nn.Cell]): Network instance.
|
|
180
|
-
|
|
181
|
-
Returns:
|
|
182
|
-
Symbol Object.
|
|
183
|
-
"""
|
|
184
|
-
var_dict = origin_net.__dict__
|
|
185
|
-
for key, value in var_dict["_cells"].items():
|
|
186
|
-
if key == symbol_name:
|
|
187
|
-
return value
|
|
188
|
-
|
|
189
|
-
for key, value in var_dict["_primitives"].items():
|
|
190
|
-
if key == symbol_name:
|
|
191
|
-
return value
|
|
192
|
-
return None
|
|
113
|
+
def _create_targets(ast_target: ast.AST) -> List[ScopedValue]:
|
|
114
|
+
"""Get targets from ast node."""
|
|
115
|
+
ast_target_elems = AstConverter.get_ast_target_elems(ast_target)
|
|
116
|
+
targets = [AstConverter.create_scopedvalue(ast_node) for ast_node in ast_target_elems]
|
|
117
|
+
return targets
|
|
193
118
|
|
|
194
119
|
@staticmethod
|
|
195
|
-
def _create_kwargs(keywords: [ast.keyword]) ->
|
|
120
|
+
def _create_kwargs(keywords: [ast.keyword]) -> Dict[str, ScopedValue]:
|
|
196
121
|
"""
|
|
197
122
|
Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node.
|
|
198
123
|
|
|
@@ -204,29 +129,133 @@ class AssignParser(Parser):
|
|
|
204
129
|
"""
|
|
205
130
|
results = {}
|
|
206
131
|
for keyword in keywords:
|
|
207
|
-
results[keyword.arg] =
|
|
132
|
+
results[keyword.arg] = AstConverter.create_scopedvalue(keyword.value)
|
|
208
133
|
return results
|
|
209
134
|
|
|
135
|
+
|
|
210
136
|
@staticmethod
|
|
211
|
-
def
|
|
137
|
+
def _get_inst_and_name(ast_node: ast.Attribute, stree: SymbolTree):
|
|
138
|
+
"""
|
|
139
|
+
Try to get instance object of ast_node from ast.Attribute.
|
|
212
140
|
"""
|
|
213
|
-
|
|
141
|
+
if not isinstance(ast_node, ast.Attribute):
|
|
142
|
+
return None, ""
|
|
143
|
+
scope_name = astunparse.unparse(ast_node).strip()
|
|
144
|
+
scope, name = scope_name.split('.', 1)
|
|
145
|
+
if scope != 'self':
|
|
146
|
+
return None, scope_name
|
|
147
|
+
if not hasattr(stree.get_origin_network(), name):
|
|
148
|
+
return None, scope_name
|
|
149
|
+
return getattr(stree.get_origin_network(), name), scope_name
|
|
150
|
+
|
|
151
|
+
@staticmethod
|
|
152
|
+
def _list_of_cells(cell_list: list):
|
|
153
|
+
"""Check if elements in the list are all cells."""
|
|
154
|
+
for item in cell_list:
|
|
155
|
+
if not isinstance(item, Cell):
|
|
156
|
+
return False
|
|
157
|
+
return True
|
|
158
|
+
|
|
159
|
+
@staticmethod
|
|
160
|
+
def _get_path_of_node_manager(node_manager: NodeManager):
|
|
161
|
+
"""Get file path of type(instance) in NodeManager"""
|
|
162
|
+
node_manager = node_manager.get_top_manager()
|
|
163
|
+
if isinstance(node_manager, SymbolTree):
|
|
164
|
+
return inspect.getfile(type(node_manager.get_origin_network()))
|
|
165
|
+
return inspect.getfile(node_manager.get_instance())
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def _get_module_of_node_manager(node_manager: NodeManager):
|
|
169
|
+
"""Get module where the node manager is located"""
|
|
170
|
+
# get module where function object is used
|
|
171
|
+
func_path = AssignParser._get_path_of_node_manager(node_manager)
|
|
172
|
+
func_path = os.path.normcase(os.path.normpath(func_path))
|
|
173
|
+
modules = list(sys.modules.values())
|
|
174
|
+
for m in modules:
|
|
175
|
+
if hasattr(m, "__file__") and m.__file__ is not None and func_path == os.path.normcase(m.__file__):
|
|
176
|
+
return m, func_path
|
|
177
|
+
return None, func_path
|
|
178
|
+
|
|
179
|
+
@staticmethod
|
|
180
|
+
def _get_object_from_module(func_full_name: str, module: types.ModuleType):
|
|
181
|
+
"""Get object from module according to full name of function"""
|
|
182
|
+
names = func_full_name.split('.')
|
|
183
|
+
obj = module
|
|
184
|
+
for attr in names:
|
|
185
|
+
if not hasattr(obj, attr):
|
|
186
|
+
logger.info(f"For '{func_full_name}', failed to get attr '{attr}' from '{obj}'")
|
|
187
|
+
return None
|
|
188
|
+
obj = getattr(obj, attr)
|
|
189
|
+
return obj
|
|
190
|
+
|
|
191
|
+
@staticmethod
|
|
192
|
+
def _get_local_var_provider(node_manager: NodeManager, var: str) -> Node:
|
|
193
|
+
"""Get the node providing specific variable"""
|
|
194
|
+
node = node_manager.get_tail()
|
|
195
|
+
while node is not None:
|
|
196
|
+
if var in [str(target) for target in node.get_targets()]:
|
|
197
|
+
return node
|
|
198
|
+
node = node.get_prev()
|
|
199
|
+
# When node_manager is control flow, nodes in upper node_manager need to be traversed.
|
|
200
|
+
if isinstance(node_manager, ControlFlow):
|
|
201
|
+
return AssignParser._get_local_var_provider(node_manager.get_node_manager(), var)
|
|
202
|
+
return None
|
|
203
|
+
|
|
204
|
+
def target(self):
|
|
205
|
+
"""Parse target type."""
|
|
206
|
+
return ast.Assign
|
|
207
|
+
|
|
208
|
+
def store_env(self):
|
|
209
|
+
"""Store current environments"""
|
|
210
|
+
self._variables_cache.append(
|
|
211
|
+
[self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs])
|
|
212
|
+
self.stree = None
|
|
213
|
+
self.ast_assign = None
|
|
214
|
+
self.node_manager = None
|
|
215
|
+
self.targets = None
|
|
216
|
+
self.args = None
|
|
217
|
+
self.kwargs = None
|
|
218
|
+
|
|
219
|
+
def restore_env(self):
|
|
220
|
+
"""Restore last environments"""
|
|
221
|
+
self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs = \
|
|
222
|
+
self._variables_cache.pop()
|
|
223
|
+
|
|
224
|
+
def _get_cell_instance(self, func_scope, func_name):
|
|
225
|
+
"""
|
|
226
|
+
Get object instance from ast.Call with type of Cell.
|
|
214
227
|
|
|
215
228
|
Args:
|
|
216
229
|
func_scope (str): Func scope.
|
|
217
230
|
func_name (str): Func name.
|
|
218
|
-
stree (SymbolTree): Belong SymbolTree.
|
|
219
231
|
|
|
220
232
|
Returns:
|
|
221
233
|
An instance represents operator instance.
|
|
222
234
|
"""
|
|
223
235
|
if func_scope != "self":
|
|
224
236
|
return None
|
|
225
|
-
var_dict = stree.get_origin_network().__dict__
|
|
237
|
+
var_dict = self.stree.get_origin_network().__dict__
|
|
226
238
|
# Instance is of type Cell
|
|
227
239
|
for key, value in var_dict["_cells"].items():
|
|
228
240
|
if key == func_name:
|
|
229
241
|
return value
|
|
242
|
+
# Instance is of other type.
|
|
243
|
+
return None
|
|
244
|
+
|
|
245
|
+
def _get_primitive_instance(self, func_scope, func_name):
|
|
246
|
+
"""
|
|
247
|
+
Get object instance from ast.Call with type of Primitive.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
func_scope (str): Func scope.
|
|
251
|
+
func_name (str): Func name.
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
An instance represents operator instance.
|
|
255
|
+
"""
|
|
256
|
+
if func_scope != "self":
|
|
257
|
+
return None
|
|
258
|
+
var_dict = self.stree.get_origin_network().__dict__
|
|
230
259
|
# Instance is of type Primitive
|
|
231
260
|
for key, value in var_dict["_primitives"].items():
|
|
232
261
|
if key == func_name:
|
|
@@ -234,46 +263,111 @@ class AssignParser(Parser):
|
|
|
234
263
|
# Instance is of other type.
|
|
235
264
|
return None
|
|
236
265
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str):
|
|
244
|
-
raise RuntimeError(f"For MindSpore Rewrite, only support str target in tuple, but got type "
|
|
245
|
-
f"{type(single_target).__name__}")
|
|
246
|
-
if single_target.type == ValueType.ConstantValue and isinstance(single_target.value, str):
|
|
247
|
-
single_target.type = ValueType.NamingValue
|
|
248
|
-
targets.append(single_target)
|
|
249
|
-
else:
|
|
250
|
-
targets.append(all_targets)
|
|
251
|
-
return targets
|
|
266
|
+
def _get_method_object(self, func_scope, func_name):
|
|
267
|
+
"""Get method object from network instance."""
|
|
268
|
+
stree = self.stree
|
|
269
|
+
if func_scope in ('self', stree.get_opt_cls_name()) and hasattr(stree.get_origin_network(), func_name):
|
|
270
|
+
return getattr(stree.get_origin_network(), func_name)
|
|
271
|
+
return None
|
|
252
272
|
|
|
253
|
-
|
|
254
|
-
|
|
273
|
+
def _get_local_variable(self, func_scope, func_name) -> (bool, object):
|
|
274
|
+
"""
|
|
275
|
+
Get local variable
|
|
276
|
+
|
|
277
|
+
Args:
|
|
278
|
+
func_scope (str): Func scope.
|
|
279
|
+
func_name (str): Func name.
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
bool: Indicate whether local variable is found.
|
|
283
|
+
object (Union[LocalPrim, type]): Instance of LocalPrim when calling the class, or class type
|
|
284
|
+
object when initializing the class.
|
|
285
|
+
"""
|
|
286
|
+
func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name
|
|
287
|
+
# try to find func_name in class variables initializing the primitive during forward method
|
|
288
|
+
provider_node = None
|
|
289
|
+
if func_scope == "self":
|
|
290
|
+
for node in self.stree.local_prim_inits():
|
|
291
|
+
if func_full_name in [str(target) for target in node.get_targets()]:
|
|
292
|
+
provider_node = node
|
|
293
|
+
# try to find func_name in local variables
|
|
294
|
+
if provider_node is None:
|
|
295
|
+
provider_node = AssignParser._get_local_var_provider(self.node_manager, func_full_name)
|
|
296
|
+
if provider_node:
|
|
297
|
+
# when the node providering the local variable initialized a primitive during forward method,
|
|
298
|
+
# we use LocalPrim to indicate the instance of this primitive. e.g. :
|
|
299
|
+
# abs_inst = P.Abs() -> 'abs_inst' is an instance of primitive initialized locally
|
|
300
|
+
# y = abs_inst(x) -> here we are parsing now
|
|
301
|
+
cls_init = provider_node.get_init_cls()
|
|
302
|
+
if cls_init and inspect.isclass(cls_init) and issubclass(cls_init, Primitive):
|
|
303
|
+
return True, LocalPrim(cls_init)
|
|
304
|
+
# when the node providering the local variable represent a primitive type object, we return
|
|
305
|
+
# type-object to indicate that we are initializing this primitive. e.g. :
|
|
306
|
+
# abs_ops = _get_cache_prim(P.Abs) -> 'abs_ops' is an primitive type object
|
|
307
|
+
# y = abs_ops(x) -> here we are parsing now
|
|
308
|
+
cls_type = provider_node.get_type_cls()
|
|
309
|
+
if cls_type and inspect.isclass(cls_type) and issubclass(cls_type, Primitive):
|
|
310
|
+
return True, cls_type
|
|
311
|
+
# local variable whose type is not primitive instance
|
|
312
|
+
logger.info(f"Ignore local variable: {func_full_name}")
|
|
313
|
+
return True, None
|
|
314
|
+
# other local variable
|
|
315
|
+
if AssignParser._get_local_var_provider(self.node_manager, func_full_name.split('.')[0]):
|
|
316
|
+
logger.info(f"Ignore local variable: {func_full_name}")
|
|
317
|
+
return True, None
|
|
318
|
+
return False, None
|
|
319
|
+
|
|
320
|
+
def _get_function_object(self, func_scope, func_name, ast_call) -> (object, bool):
|
|
321
|
+
"""
|
|
322
|
+
Get function object from module.
|
|
323
|
+
|
|
324
|
+
If the code represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs),
|
|
325
|
+
return primitive type object with class type flag True.
|
|
326
|
+
|
|
327
|
+
if the code represent an initializtion of a class, e.g. abs_inst = P.Abs(),
|
|
328
|
+
return primitive type object with class type flag False.
|
|
329
|
+
|
|
330
|
+
if the code represent the call of function or class instance, e.g. y = abs_inst(x)/func(x),
|
|
331
|
+
return primitive instance or function object with class type flag False.
|
|
332
|
+
|
|
333
|
+
Args:
|
|
334
|
+
func_scope (str): Func scope.
|
|
335
|
+
func_name (str): Func name.
|
|
336
|
+
ast_call (ast.Call): ast.Call of ast.Assign.
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
object: Class type object, class instance or function object
|
|
340
|
+
bool: Flag indicate is node represent a class type object.
|
|
341
|
+
"""
|
|
342
|
+
func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name
|
|
343
|
+
# get module where function object is used
|
|
344
|
+
module, func_path = AssignParser._get_module_of_node_manager(self.node_manager)
|
|
345
|
+
if module is None:
|
|
346
|
+
logger.debug(f"When getting object of '{func_full_name}', failed to find module in '{func_path}'")
|
|
347
|
+
return None, False
|
|
348
|
+
# if name of function is _get_cache_prim, return primitive type object
|
|
349
|
+
is_cls_type_obj = False
|
|
350
|
+
if func_full_name == '_get_cache_prim':
|
|
351
|
+
func_full_name = astunparse.unparse(ast_call.args[0]).strip()
|
|
352
|
+
is_cls_type_obj = True
|
|
353
|
+
# find object in module
|
|
354
|
+
obj = AssignParser._get_object_from_module(func_full_name, module)
|
|
355
|
+
return obj, is_cls_type_obj
|
|
356
|
+
|
|
357
|
+
def _update_field_in_init(self, func_name: str, sub_tree: SymbolTree) -> bool:
|
|
255
358
|
"""
|
|
256
359
|
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
|
|
257
360
|
Add the code like: `self.field = SubNetwork(self.field)`
|
|
258
361
|
|
|
259
362
|
Args:
|
|
260
|
-
|
|
261
|
-
func_name (str): A string represents function symbol.
|
|
262
|
-
stree (SymbolTree): The SymbolTree corresponding to main-network.
|
|
363
|
+
func_name (str): A string represents scope and name of function symbol.
|
|
263
364
|
sub_tree (SymbolTree): The SymbolTree corresponding to sub-network.
|
|
264
|
-
|
|
265
|
-
Raises:
|
|
266
|
-
NotImplementedError: If `func_scope` is not "self", it means corresponding op is inited in forward method.
|
|
267
|
-
NotImplementedError: If targets of ast.Assign of corresponding field in `__init__` method.
|
|
268
365
|
"""
|
|
269
|
-
|
|
270
|
-
logger.warning("Not support parse operator which is instantiated at runtime now: %s; name: %s", func_scope,
|
|
271
|
-
func_name)
|
|
272
|
-
init_func_ast = stree.get_init_func_ast()
|
|
366
|
+
init_func_ast = self.stree.get_init_func_ast()
|
|
273
367
|
sub_net_obj = sub_tree.get_origin_network()
|
|
274
368
|
sub_net_opt_name = sub_tree.get_opt_cls_name()
|
|
275
369
|
# Add .to_float(mindspore.float16) if origin subnet has this attribute
|
|
276
|
-
new_code = f"{
|
|
370
|
+
new_code = f"{func_name} = {sub_net_opt_name}({func_name})"
|
|
277
371
|
if hasattr(sub_net_obj, "fp16") and sub_net_obj.fp16:
|
|
278
372
|
new_code = f"{new_code}.to_float(mindspore.float16)"
|
|
279
373
|
elif hasattr(sub_net_obj, "bf16") and sub_net_obj.bf16:
|
|
@@ -281,38 +375,7 @@ class AssignParser(Parser):
|
|
|
281
375
|
new_ast = ast.parse(new_code).body[0]
|
|
282
376
|
init_func_ast.body.append(new_ast)
|
|
283
377
|
|
|
284
|
-
|
|
285
|
-
def _create_inputs_for_cell_container(ast_assign) -> ['Node']:
|
|
286
|
-
"""Create inputs for cell container first node."""
|
|
287
|
-
call_ast_node = ast_assign.value
|
|
288
|
-
if not isinstance(call_ast_node, ast.Call):
|
|
289
|
-
raise RuntimeError(error_str(f"when creating input node for cellcontainer, value of input father ast node"
|
|
290
|
-
"is not ast.Call!'", child_node=call_ast_node, father_node=ast_assign))
|
|
291
|
-
first_node_inputs: ['Node'] = []
|
|
292
|
-
exist_param_name = []
|
|
293
|
-
for arg in call_ast_node.args:
|
|
294
|
-
if isinstance(arg, ast.Name):
|
|
295
|
-
param_name = arg.id
|
|
296
|
-
elif isinstance(arg, ast.arg):
|
|
297
|
-
param_name = arg.arg
|
|
298
|
-
else:
|
|
299
|
-
raise RuntimeError(error_str(f"only support ast.arg, ast.arg in arguments arg, but got "
|
|
300
|
-
f"'{type(arg).__name__}'", child_node=arg, father_node=call_ast_node))
|
|
301
|
-
if param_name in exist_param_name:
|
|
302
|
-
raise RuntimeError(error_str(f"Cellcontianer has duplicate input names", child_node=arg,
|
|
303
|
-
father_node=call_ast_node))
|
|
304
|
-
exist_param_name.append(param_name)
|
|
305
|
-
node = Node.create_input_node(arg, param_name, name=f"input_{param_name}")
|
|
306
|
-
first_node_inputs.append(node)
|
|
307
|
-
|
|
308
|
-
if call_ast_node.keywords:
|
|
309
|
-
raise RuntimeError(error_str(f"Not support keyword input for cellcontainer now.",
|
|
310
|
-
child_node=call_ast_node, father_node=ast_assign))
|
|
311
|
-
|
|
312
|
-
return first_node_inputs
|
|
313
|
-
|
|
314
|
-
@staticmethod
|
|
315
|
-
def _update_cell_container_in_init(stree, container_name, container_idx, subnet_opt_name):
|
|
378
|
+
def _update_cell_container_in_init(self, container_name, container_idx, subnet_opt_name):
|
|
316
379
|
"""
|
|
317
380
|
When nn.SequentialCell include sub-symboltree, the new class definition will be used to create object.
|
|
318
381
|
So the assign code will be got from origin code first, and then be modified to new class name.
|
|
@@ -328,173 +391,300 @@ class AssignParser(Parser):
|
|
|
328
391
|
"""
|
|
329
392
|
new_code = f"{container_name}[{container_idx}] = {subnet_opt_name}({container_name}[{container_idx}])"
|
|
330
393
|
new_ast = ast.parse(new_code).body[0]
|
|
331
|
-
stree.get_init_func_ast().body.append(new_ast)
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
394
|
+
self.stree.get_init_func_ast().body.append(new_ast)
|
|
395
|
+
|
|
396
|
+
def _add_import(self, import_name: str):
|
|
397
|
+
""" add import to current node manager."""
|
|
398
|
+
module, _ = AssignParser._get_module_of_node_manager(self.node_manager)
|
|
399
|
+
if module is None:
|
|
400
|
+
logger.info(f"Cannot get module where '{import_name}' is located, ignore import info")
|
|
401
|
+
return
|
|
402
|
+
node_manager = self.node_manager.get_top_manager()
|
|
403
|
+
belonging_ast = None if isinstance(node_manager, SymbolTree) else node_manager.get_manager_ast()
|
|
404
|
+
self.stree.add_import(module, import_name, belonging_ast)
|
|
405
|
+
|
|
406
|
+
def cell_container_process(self, func_name: str, node_name: str, container_obj: object):
|
|
336
407
|
""" parse cell container object."""
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
408
|
+
# create unparsable node if container is already parsed when sharing one implementation
|
|
409
|
+
if AssignParser._share_one_implementation and id(container_obj) in AssignParser._cached_cell_containers:
|
|
410
|
+
cell_container = Node.create_call_buildin_op(container_obj, self.ast_assign, self.targets,
|
|
411
|
+
func_name, self.args, self.kwargs, node_name)
|
|
412
|
+
return cell_container
|
|
413
|
+
cell_container = CellContainer(self.ast_assign, self.targets, func_name, self.args, self.kwargs,
|
|
414
|
+
node_name, self.stree, container_obj)
|
|
340
415
|
for i, cell in enumerate(container_obj):
|
|
341
416
|
cell_name = type(cell).__name__
|
|
342
|
-
|
|
343
|
-
if
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
417
|
+
# The type of cell is container of cells (e.g. SequentialCell)
|
|
418
|
+
if isinstance(cell, tuple(AssignParser.types_for_cell_container)):
|
|
419
|
+
sub_node = self.cell_container_process(f"{func_name}[{i}]", cell_name, cell)
|
|
420
|
+
elif is_subtree(cell):
|
|
421
|
+
# create unparsable node if tree node is already parsed when sharing one implementation
|
|
422
|
+
if AssignParser._share_one_implementation and id(cell) in AssignParser._cached_trees:
|
|
423
|
+
first_stree = AssignParser._cached_trees.get(id(cell))
|
|
424
|
+
self._update_cell_container_in_init(func_name, i, first_stree.get_opt_cls_name())
|
|
425
|
+
sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args,
|
|
426
|
+
self.kwargs, cell_name)
|
|
427
|
+
else:
|
|
428
|
+
from ..symbol_tree import SymbolTreeBuilder
|
|
429
|
+
stb = SymbolTreeBuilder(cell)
|
|
430
|
+
new_stree = stb.build()
|
|
431
|
+
sub_node = TreeNode.create_tree_node(new_stree, None, self.targets, cell_name, self.args,
|
|
432
|
+
self.kwargs, cell_name, cell)
|
|
433
|
+
self._update_cell_container_in_init(func_name, i, new_stree.get_opt_cls_name())
|
|
434
|
+
# save symbol tree if it is firstly parsed when sharing one implementation
|
|
435
|
+
if AssignParser._share_one_implementation:
|
|
436
|
+
AssignParser._cached_trees[id(cell)] = new_stree
|
|
349
437
|
else:
|
|
350
|
-
sub_node = Node.create_call_buildin_op(cell, None, targets, cell_name,
|
|
351
|
-
|
|
438
|
+
sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args,
|
|
439
|
+
self.kwargs, cell_name)
|
|
352
440
|
# add sub node to cell_container
|
|
353
441
|
cell_container.append(sub_node, False)
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
for idx, arg_provider in enumerate(first_node_inputs):
|
|
358
|
-
sub_node.set_arg_providers(idx, (arg_provider, 0))
|
|
359
|
-
else:
|
|
360
|
-
sub_node.set_arg_providers(0, (cell_container.node_list[i-1], 0))
|
|
442
|
+
# save the node if container is firstly parsed when sharing one implementation
|
|
443
|
+
if AssignParser._share_one_implementation:
|
|
444
|
+
AssignParser._cached_cell_containers[id(container_obj)] = cell_container
|
|
361
445
|
return cell_container
|
|
362
446
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
if
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
447
|
+
def process_cell(self, func_scope_name: ScopedValue, node_name: str, cell_inst: Cell):
|
|
448
|
+
"""Create CallCell node with instance of cell."""
|
|
449
|
+
# The type of cell is container of cells (e.g. SequentialCell)
|
|
450
|
+
if isinstance(cell_inst, tuple(AssignParser.types_for_cell_container)):
|
|
451
|
+
node = self.cell_container_process(func_scope_name, node_name, cell_inst)
|
|
452
|
+
# The type of cell is user custom network, then we create sub-symboltree
|
|
453
|
+
elif is_subtree(cell_inst):
|
|
454
|
+
# create unparsable node if tree node is already parsed when sharing one implementation
|
|
455
|
+
if AssignParser._share_one_implementation and id(cell_inst) in AssignParser._cached_trees:
|
|
456
|
+
first_stree = AssignParser._cached_trees.get(id(cell_inst))
|
|
457
|
+
self._update_field_in_init(str(func_scope_name), first_stree)
|
|
458
|
+
node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name,
|
|
459
|
+
self.args, self.kwargs, node_name)
|
|
460
|
+
else:
|
|
461
|
+
from ..symbol_tree import SymbolTreeBuilder
|
|
462
|
+
stb = SymbolTreeBuilder(cell_inst)
|
|
463
|
+
new_stree = stb.build()
|
|
464
|
+
self._update_field_in_init(str(func_scope_name), new_stree)
|
|
465
|
+
node = TreeNode.create_tree_node(new_stree, self.ast_assign, self.targets, func_scope_name,
|
|
466
|
+
self.args, self.kwargs, node_name, new_stree.get_origin_network())
|
|
467
|
+
# save symbol tree if it is firstly parsed when sharing one implementation
|
|
468
|
+
if AssignParser._share_one_implementation:
|
|
469
|
+
AssignParser._cached_trees[id(cell_inst)] = new_stree
|
|
470
|
+
else:
|
|
471
|
+
# The type of cell is built-in cells
|
|
472
|
+
node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name, self.args,
|
|
473
|
+
self.kwargs, node_name)
|
|
474
|
+
self.stree.append_origin_field(node, self.node_manager)
|
|
475
|
+
|
|
476
|
+
def process_primitive(self, func_scope_name: ScopedValue, node_name: str, primitive_inst: Primitive):
|
|
477
|
+
"""Create CallPrimitive node with instance of primitive."""
|
|
478
|
+
node = Node.create_call_buildin_op(primitive_inst, self.ast_assign, self.targets, func_scope_name,
|
|
479
|
+
self.args, self.kwargs, node_name)
|
|
480
|
+
self.stree.append_origin_field(node, self.node_manager)
|
|
481
|
+
|
|
482
|
+
def process_class_method(self, func_scope_name: ScopedValue, node_name: str, method_object: object):
|
|
483
|
+
"""Create CallFunction node for class method function."""
|
|
484
|
+
func_name = func_scope_name.value
|
|
485
|
+
# get ast.FunctionDef
|
|
389
486
|
ast_functiondef = None
|
|
390
|
-
for body in stree.get_class_ast().body:
|
|
487
|
+
for body in self.stree.get_class_ast().body:
|
|
391
488
|
if isinstance(body, ast.FunctionDef) and func_name == body.name:
|
|
392
489
|
ast_functiondef = body
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
490
|
+
if ast_functiondef is None:
|
|
491
|
+
# method of child class may be called and will be ignored now.
|
|
492
|
+
logger.info(error_str(f"Find ast of function '{func_name}' in network '{self.stree.get_ori_cls_name()}' "
|
|
493
|
+
f"failed", child_node=self.ast_assign))
|
|
494
|
+
self.insert_callfunction_node(func_scope_name, node_name, None, None, False)
|
|
495
|
+
else:
|
|
496
|
+
# create CallFunction node
|
|
497
|
+
self.insert_callfunction_node(func_scope_name, node_name, ast_functiondef, method_object, True)
|
|
498
|
+
|
|
499
|
+
def process_function(self, func_scope_name: ScopedValue, node_name: str, function_object: object,
                     is_cls_type_obj: bool):
    """
    Create node(s) for a call whose callee resolved to a function-like object.

    The first matching rule decides what is created:

    - the object is in ``AssignParser._function_parse_black_list``, is already being parsed by an enclosing
      CallFunction node (recursive call), is a third-party function, or is a mindspore function while
      ``AssignParser._skip_ms_function`` is set: an unparsed CallFunction node;
    - it is a Primitive instance: a CallPrimitive node, via ``process_primitive``;
    - it is a class: an unparsed CallFunction node marked with ``set_type_cls`` (``is_cls_type_obj`` True) or
      ``set_init_cls`` (``is_cls_type_obj`` False);
    - otherwise the function source is parsed. If the source cannot be obtained or parsed, does not start with
      a ``def``, or contains nested ``def`` statements, an unparsed CallFunction node is created. Otherwise:
      - the function is renamed to a unique name;
      - the call in ``self.ast_assign`` is redirected to that name;
      - the ast is saved into the symbol tree's external asts;
      - a CallFunction node with the parsed body is created.

    Args:
        func_scope_name (ScopedValue): Scoped name of the callee as written in the call.
        node_name (str): Name of the node to create in the unparsed cases.
        function_object (object): The resolved callee.
        is_cls_type_obj (bool): Only used when ``function_object`` is a class: True when the call produces the
            class type itself (e.g. ``abs_ops = _get_cache_prim(P.Abs)``), False when it instantiates the class
            (e.g. ``abs_inst = P.Abs()``).
    """
    # Ignore functions in _function_parse_black_list
    if function_object in AssignParser._function_parse_black_list:
        logger.debug(f"'{func_scope_name}' is in the _function_parse_black_list and will not be parsed")
        if not func_scope_name.scope:
            self._add_import(func_scope_name.value)
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # break loop function: walk up the chain of enclosing node managers and stop if this very function is
    # already being parsed by one of them, otherwise a recursive function would be expanded forever.
    node_manager = self.node_manager
    while node_manager and isinstance(node_manager, Node):
        if isinstance(node_manager, CallFunction) and node_manager.get_instance() == function_object:
            logger.info(f"loop function detected in '{func_scope_name}', stop parsing function.")
            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
            return
        node_manager = node_manager.get_node_manager()
    # process primitive instances:
    # (global/local) _ops_func = P.FUNC()
    # (here) y = _ops_func(x) <- (process: _ops_func)
    if isinstance(function_object, Primitive):
        # when primitive instance is not a local variable, it will be a global object which need to be imported
        if not isinstance(function_object, LocalPrim):
            import_name = str(func_scope_name).split('.')[0]
            self._add_import(import_name)
        # create CallPrimitive node
        self.process_primitive(func_scope_name, func_scope_name.value, function_object)
        return
    # process primitive object:
    # (here) _ops_func = P.FUNC() <- (process: P.FUNC)
    # (later) y = _ops_func(x)
    if inspect.isclass(function_object):
        node = self.insert_callfunction_node(func_scope_name, node_name, None, None, False)
        if is_cls_type_obj:
            # represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
            node.set_type_cls(function_object)
            # add import of the class argument when it is referenced by a bare name
            if str(func_scope_name) == '_get_cache_prim':
                import_name = astunparse.unparse(self.ast_assign.value.args[0]).strip()
                if '.' not in import_name:
                    self._add_import(import_name)
        else:
            # represent the initialize of a class type, e.g. abs_inst = P.Abs()
            node.set_init_cls(function_object)
            # record local primitive objects
            if func_scope_name.scope == 'self' and issubclass(function_object, Primitive):
                self.stree.local_prim_inits.append(node)
        return
    # process third party functions
    is_ms_func = is_ms_function(function_object)
    if not is_ms_func and is_third_party(function_object):
        logger.info(f"Ignore third party function '{func_scope_name}'.")
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # process mindspore functions
    if is_ms_func and AssignParser._skip_ms_function:
        logger.info(f"Ignore mindspore function '{func_scope_name}'.")
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # get ast.FunctionDef
    try:
        source_code = inspect.getsource(function_object)
        ast_functiondef = ast.parse(dedent(source_code)).body[0]
    except (OSError, TypeError, SyntaxError) as err:
        # inspect.getsource raises TypeError for built-in (C) functions and OSError when no source is
        # available; the extracted lines may also not parse on their own (e.g. an inline lambda inside a
        # multi-line expression). Keep the call unparsed instead of aborting the whole rewrite.
        logger.info(error_str(f"Get source code of function {str(func_scope_name)} failed: {err}",
                              child_node=self.ast_assign))
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    if not isinstance(ast_functiondef, ast.FunctionDef):
        logger.info(error_str(f"Get ast.FunctionDef of function {str(func_scope_name)} failed, the type of "
                              f"ast node is {type(ast_functiondef)}", child_node=self.ast_assign))
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # only directly nested ``def`` statements are detected here
    if any(isinstance(stmt, ast.FunctionDef) for stmt in ast_functiondef.body):
        logger.info(error_str(f"closure syntax is not supported now, {str(func_scope_name)} will not be parsed.",
                              child_node=ast_functiondef))
        if not func_scope_name.scope:
            self._add_import(func_scope_name.value)
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # update func_name, and remove scope
    new_name = ast_functiondef.name
    # when func_scope_name(e.g. 'C.uniform') is not the name in ast.FunctionDef(e.g. 'uniform'), this name may be
    # already used as variable(e.g. uniform = C.uniform(x)).
    # To avoid new function's name being duplicated with existed variable, an suffix '_opt' will be added.
    if new_name != str(func_scope_name):
        new_name = f"{new_name}_opt"
    new_name = FunctionNamer().instance().get_name(new_name)
    # create unparsable node if function is already parsed when sharing one implementation
    if AssignParser._share_one_implementation and id(function_object) in AssignParser._cached_functions:
        first_node = AssignParser._cached_functions.get(id(function_object))
        ast_call: ast.Call = self.ast_assign.value
        # redirect the call to the implementation created for the first parsed call
        ast_call.func = ast.Name(id=str(first_node.get_func_name()), ctx=ast.Load())
        self.insert_callfunction_node(func_scope_name, new_name, None, function_object, False)
        return
    ast_functiondef.name = new_name
    ast_call: ast.Call = self.ast_assign.value
    ast_call.func = ast.Name(id=new_name, ctx=ast.Load())
    # save ast.FunctionDef into stree._external_ast
    self.stree.get_external_ast()[ast_functiondef] = []
    # import module which function defined in
    func_file_path = inspect.getabsfile(function_object)
    self.stree.save_imports_from_file(func_file_path, ast_functiondef)
    # create CallFunction node (the renamed function has no scope any more)
    func_scope_name = ScopedValue.create_naming_value(new_name, "")
    node = self.insert_callfunction_node(func_scope_name, new_name, ast_functiondef, function_object, False)
    # save function node if it is firstly parsed when sharing one implementation
    if AssignParser._share_one_implementation:
        AssignParser._cached_functions[id(function_object)] = node
|
|
603
|
+
|
|
604
|
+
def insert_callfunction_node(self, func_name: ScopedValue, node_name: str, ast_functiondef: ast.FunctionDef,
                             func_obj: object, is_method: bool) -> Node:
    """
    Create a CallFunction node for a function call and append it to the symbol tree.

    Without ``ast_functiondef``, the node is created by ``Node.inner_create_call_function`` and no function
    body is parsed. With ``ast_functiondef``, the following happens in order:

    - a CallFunction node is created and appended;
    - the function ast is flattened by ``AstFlattener``;
    - the flattened ast is handed to the parser registered for ``ast.FunctionDef``, with the new node passed
      as ``node_manager``.

    Args:
        func_name (ScopedValue): Scoped name of the called function.
        node_name (str): Name of the node to create.
        ast_functiondef (ast.FunctionDef): Ast of the function body to parse, or None to skip parsing.
        func_obj (object): The called function object (may be None).
        is_method (bool): Whether the callee is a method of the network class.

    Returns:
        Node: The created node.
    """
    if ast_functiondef is not None:
        call_node = CallFunction(self.targets, func_name, self.args, self.kwargs, node_name, self.ast_assign,
                                 ast_functiondef, self.stree, func_obj, is_method)
    else:
        call_node = Node.inner_create_call_function(node_name, self.ast_assign, func_name, func_obj,
                                                    self.targets, self.args, self.kwargs)
    self.stree.append_origin_field(call_node, self.node_manager)
    if ast_functiondef is None:
        return call_node
    # expand ast codes, then parse them into the CallFunction node
    flat_functiondef = AstFlattener().transform(ast_functiondef, [func_name.value], self.stree)
    funcdef_parser = ParserRegister.instance().get_parser(ast.FunctionDef)
    funcdef_parser.process(self.stree, flat_functiondef, node_manager=call_node)
    return call_node
|
|
407
622
|
|
|
408
|
-
def
|
|
409
|
-
node_manager: NodeManager) -> Node:
|
|
623
|
+
def process_ast_call(self, ast_call: ast.Call):
    """
    Convert ast.Call to a symbol tree node.

    ``self.targets``, ``self.args`` and ``self.kwargs`` are first filled from the assignment and the call. The
    callee is then resolved in the order below, and the first match decides how the node is created:

    1. the callee expression itself contains a call (e.g. ``func(x)(y)``): unparsed CallFunction node;
    2. the full callee name is a name in ``builtins``: unparsed CallFunction node;
    3. the callee is a loop variable of the enclosing ControlFlow: unparsed CallFunction node;
    4. a Cell instance: ``process_cell``;
    5. a Primitive instance: ``process_primitive``;
    6. a method object: ``process_class_method`` for bound methods, otherwise ``process_function``;
    7. a local variable: ``process_function`` when it holds a primitive object, otherwise an unparsed node;
    8. a function object: ``process_function``;
    9. nothing resolved: the failure is logged and an unparsed CallFunction node is created.

    Args:
        ast_call (ast.Call): An ast.Call of assign node in construct.
    """
    self.targets = AssignParser._create_targets(self.ast_assign.targets[0])
    self.args = [AstConverter.create_scopedvalue(arg) for arg in ast_call.args]
    self.kwargs = AssignParser._create_kwargs(ast_call.keywords)
    func_name = AssignParser._get_func_name(ast_call)
    func_scope = AssignParser._get_func_scope(ast_call)
    func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
    func_full_name = str(func_scope_name)
    # y = func(xxx)(xxx) / y = func1(xxx).func2(xxx) is not supported, and should be flattened before parsing.
    if AstFinder(ast_call.func).find_all(ast.Call):
        logger.info(error_str("ast.Call in func name of ast.Call is not supported.", ast_call, self.ast_assign))
        self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
        return
    # Ignore built-in functions. vars(builtins) holds the same names dir(builtins) lists, but is a dict
    # lookup instead of building, sorting and scanning a fresh list on every call.
    if func_full_name in vars(builtins):
        logger.info(f"Ignore built-in function: {func_scope_name}")
        self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
        return
    # Ignore function name is target of for loop
    if isinstance(self.node_manager, ControlFlow) and func_full_name in self.node_manager.loop_vars:
        logger.info(f"Ignore function of loop variable: {func_scope_name}")
        self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
        return
    # Instance with type of Cell
    cell_inst = self._get_cell_instance(func_scope, func_name)
    if cell_inst is not None:
        self.process_cell(func_scope_name, func_name, cell_inst)
        return
    # Instance with type of Primitive
    primitive_inst = self._get_primitive_instance(func_scope, func_name)
    if primitive_inst is not None:
        self.process_primitive(func_scope_name, func_name, primitive_inst)
        return
    # Class method object
    method_object = self._get_method_object(func_scope, func_name)
    if method_object is not None:
        if inspect.ismethod(method_object):
            self.process_class_method(func_scope_name, func_name, method_object)
        else:
            self.process_function(func_scope_name, func_name, method_object, False)
        return
    # Local variable
    is_local_var, primitive_obj = self._get_local_variable(func_scope, func_name)
    if primitive_obj is not None:
        self.process_function(func_scope_name, func_name, primitive_obj, False)
        return
    if is_local_var:
        # for a variable whose type is not primitive instance, create normal node for it
        self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
        return
    # Function object
    function_object, is_cls_type_obj = self._get_function_object(func_scope, func_name, ast_call)
    if function_object is not None:
        self.process_function(func_scope_name, func_name, function_object, is_cls_type_obj)
        return
    logger.info(error_str("Failed to get instance or object of ast.Call.", ast_call, self.ast_assign))
    self.insert_callfunction_node(func_scope_name, func_name, None, None, False)
|
|
686
|
+
|
|
687
|
+
def process_ast_mathops(self, ast_op: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]):
    """
    Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
    a symbol tree node.

    The node name is the ast type name followed by the operator type name(s), lower-cased. For example,
    ``a + b`` gives ``binop_add`` and ``a < b <= c`` gives ``compare_lt_lte``. The operands are converted to
    ScopedValue arguments in source order.

    Args:
        ast_op (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
            operation in construct function.

    Raises:
        TypeError: The type of parameter 'ast_op' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
    """
    if not isinstance(ast_op, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
        # Format the message into a single string; passing the type as a second exception argument
        # would render the message as a tuple.
        raise TypeError("The type of parameter 'ast_op' must be one of (ast.BinOp, ast.UnaryOp, "
                        f"ast.BoolOp, ast.Compare), but got {type(ast_op).__name__}")

    targets = AssignParser._create_targets(self.ast_assign.targets[0])
    args = []
    op_type_str = type(ast_op).__name__
    op_type = ScopedValue.create_naming_value(op_type_str)
    name = op_type_str
    if isinstance(ast_op, ast.BinOp):
        name = f'{name}_{type(ast_op.op).__name__}'
        args.append(AstConverter.create_scopedvalue(ast_op.left))
        args.append(AstConverter.create_scopedvalue(ast_op.right))
    elif isinstance(ast_op, ast.UnaryOp):
        name = f'{name}_{type(ast_op.op).__name__}'
        args.append(AstConverter.create_scopedvalue(ast_op.operand))
    elif isinstance(ast_op, ast.BoolOp):
        name = f'{name}_{type(ast_op.op).__name__}'
        for value in ast_op.values:
            args.append(AstConverter.create_scopedvalue(value))
    else:
        # ast.Compare: ``left op0 comparators[0] op1 comparators[1] ...``
        args.append(AstConverter.create_scopedvalue(ast_op.left))
        for ast_cmp_op, comparator in zip(ast_op.ops, ast_op.comparators):
            name = f'{name}_{type(ast_cmp_op).__name__}'
            args.append(AstConverter.create_scopedvalue(comparator))
    name = name.lower()
    node = Node.create_mathops_node(self.ast_assign, targets, op_type, args, name)
    self.stree.append_origin_field(node, self.node_manager)
|
|
732
|
+
|
|
733
|
+
def process_ast_constant(self, ast_constant: Union[ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str]):
    """
    Convert a constant-valued assignment to a symbol tree node.

    A ``pass_through`` call-method node is created whose single argument is the constant. The node name is
    the lower-cased ast type name with an ``_assign`` suffix (e.g. ``constant_assign``).

    Args:
        ast_constant: The constant ast node on the right-hand side of the assignment.
    """
    const_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    const_args = [AstConverter.create_scopedvalue(ast_constant)]
    const_node_name = type(ast_constant).__name__.lower() + "_assign"
    const_node = Node.create_call_method(self.ast_assign, const_targets, "pass_through", const_args, {},
                                         const_node_name)
    self.stree.append_origin_field(const_node, self.node_manager)
|
|
743
|
+
|
|
744
|
+
def process_ast_name(self, ast_node: Union[ast.Name, ast.Attribute]):
    """
    Convert an assignment whose value is an ast.Name or ast.Attribute to a symbol tree node.

    The value is resolved with ``AssignParser._get_inst_and_name``. A CellList or a list of cells is handled
    by ``cell_container_process``. Any other value becomes a ``pass_through`` call-method node named after
    the lower-cased ast type (e.g. ``name_assign``). The resulting node is appended to the symbol tree.

    Args:
        ast_node (Union[ast.Name, ast.Attribute]): The right-hand side of the assignment.
    """
    self.targets = AssignParser._create_targets(self.ast_assign.targets[0])
    inst, scope_name = AssignParser._get_inst_and_name(ast_node, self.stree)
    holds_cells = inst is not None and (
        isinstance(inst, CellList) or (isinstance(inst, list) and AssignParser._list_of_cells(inst)))
    if holds_cells:
        node = self.cell_container_process(scope_name, scope_name, inst)
    else:
        alias_name = type(ast_node).__name__.lower() + "_assign"
        alias_args = [AstConverter.create_scopedvalue(ast_node)]
        node = Node.create_call_method(self.ast_assign, self.targets, "pass_through", alias_args, {}, alias_name)
    self.stree.append_origin_field(node, self.node_manager)
|
|
758
|
+
|
|
759
|
+
def process_ast_tuple(self, ast_node: Union[ast.Tuple, ast.List]):
    """
    Convert an assignment whose value is an ast.Tuple or ast.List to a symbol tree node.

    The elements are first checked with ``AstConverter.ast_tuple_elts_support_scopledvalue``.

    - When every element can be represented as a ScopedValue, a call-method node named ``tuple`` (for
      ast.Tuple) or ``list`` (otherwise) is created, with one argument per element.
    - Otherwise the whole assignment is appended as a python node.

    Args:
        ast_node (Union[ast.Tuple, ast.List]): The right-hand side of the assignment.
    """
    if not AstConverter.ast_tuple_elts_support_scopledvalue(ast_node):
        logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported "
                    "in rewrite, fallback to python")
        self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager)
        return
    seq_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    elt_values = [AstConverter.create_scopedvalue(elt) for elt in ast_node.elts]
    container_name = "tuple" if isinstance(ast_node, ast.Tuple) else "list"
    seq_node = Node.create_call_method(self.ast_assign, seq_targets, container_name, elt_values, {},
                                       container_name)
    self.stree.append_origin_field(seq_node, self.node_manager)
|
|
776
|
+
|
|
777
|
+
def process_ast_dict(self, ast_dict: ast.Dict):
    """
    Convert an assignment whose value is an ast.Dict to a symbol tree node.

    When ``AstConverter.ast_dict_support_scopledvalue`` accepts the dict, a ``dict`` call-method node is
    created. It has no positional arguments and one keyword argument per entry; each key's ``.value`` is used
    as the keyword name. Otherwise the whole assignment is appended as a python node.

    Args:
        ast_dict (ast.Dict): The right-hand side of the assignment.
    """
    if not AstConverter.ast_dict_support_scopledvalue(ast_dict):
        logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported "
                    "in rewrite, fallback to python")
        self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager)
        return
    dict_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    entries = {}
    for key_node, value_node in zip(ast_dict.keys, ast_dict.values):
        entries[key_node.value] = AstConverter.create_scopedvalue(value_node)
    dict_func = ScopedValue.create_naming_value("dict")
    dict_node = Node.create_call_method(self.ast_assign, dict_targets, dict_func, [], entries, "dict")
    self.stree.append_origin_field(dict_node, self.node_manager)
|
|
794
|
+
|
|
795
|
+
def process_ast_subscript(self, ast_subscript: ast.Subscript):
    """
    Convert an assignment whose value is an ast.Subscript to a symbol tree node.

    A ``pass_through`` call-method node named ``subscript_var`` is created. Its single argument is the
    subscript expression converted to a ScopedValue.

    Args:
        ast_subscript (ast.Subscript): The right-hand side of the assignment.
    """
    sub_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    sub_value = AstConverter.create_scopedvalue(ast_subscript)
    sub_node = Node.create_call_method(self.ast_assign, sub_targets, "pass_through", [sub_value], {},
                                       "subscript_var")
    self.stree.append_origin_field(sub_node, self.node_manager)
|
|
550
803
|
|
|
551
804
|
def process(self, stree: SymbolTree, node: ast.Assign, node_manager: NodeManager):
    """
    Parse an ast.Assign node and append the corresponding node(s) to the symbol tree.

    The type of the assigned value selects a dedicated handler:

    - ast.Call: ``process_ast_call``;
    - ast.BinOp / ast.UnaryOp / ast.BoolOp / ast.Compare: ``process_ast_mathops``;
    - ast.Subscript: ``process_ast_subscript``;
    - constants: ``process_ast_constant``;
    - ast.Name / ast.Attribute: ``process_ast_name``;
    - ast.Tuple / ast.List: ``process_ast_tuple``;
    - ast.Dict: ``process_ast_dict``.

    Any other value, or an assignment with more than one target (``a = b = 1``), is appended as a python node.

    Args:
        stree ([SymbolTree]): Symbol Tree under parsing.
        node ([ast.Assign]): An ast.Assign node.
        node_manager (NodeManager): NodeManager those asts belong to.
    """
    if len(node.targets) != 1:
        logger.info(error_str("Continuous assignment statement(e.g. 'a = b = 1') should be flatten before.",
                              child_node=node))
        stree.try_append_python_node(node, node, node_manager)
        return

    # This parser is a single shared instance (g_assign_parser). Its per-assign state (stree, ast_assign,
    # node_manager, and the targets/args/kwargs set by the handlers) is presumably saved by store_env() and
    # restored by restore_env(), so that nested parsing of called function bodies does not clobber it.
    # Restore in `finally` so an exception from a handler cannot leave that state unbalanced.
    self.store_env()
    try:
        self.stree = stree
        self.ast_assign = node
        self.node_manager = node_manager
        value = node.value
        if isinstance(value, ast.Call):
            self.process_ast_call(value)
        elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
            self.process_ast_mathops(value)
        elif isinstance(value, ast.Subscript):
            self.process_ast_subscript(value)
        elif isinstance(value, (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str)):
            self.process_ast_constant(value)
        elif isinstance(value, (ast.Name, ast.Attribute)):
            self.process_ast_name(value)
        elif isinstance(value, (ast.Tuple, ast.List)):
            self.process_ast_tuple(value)
        elif isinstance(value, ast.Dict):
            self.process_ast_dict(value)
        else:
            logger.info(f"ops-call({astunparse.unparse(node).strip()}) in assign will be supported in near future, "
                        "ignored as a python node now")
            stree.try_append_python_node(node, node, node_manager)
    finally:
        self.restore_env()
|
|
626
848
|
|
|
627
849
|
|
|
628
850
|
# Module-level AssignParser instance handed to reg_parser (presumably registering it as the parser for
# ast.Assign, retrievable through ParserRegister — confirm in reg_parser). Because one instance serves every
# assignment, AssignParser.process brackets its per-assign state with store_env()/restore_env().
g_assign_parser = reg_parser(AssignParser())
|