mindspore 2.2.11__cp39-cp39-win_amd64.whl → 2.3.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the registry's advisory for this release for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +7 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +76 -18
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +258 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +174 -62
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +142 -240
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +51 -24
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/common/__init__.py +15 -4
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +8 -9
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +411 -106
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +6 -8
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +260 -0
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +30 -11
- mindspore/common/recompute.py +262 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +272 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +468 -496
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +1140 -0
- mindspore/communication/management.py +118 -102
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +378 -65
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +163 -83
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +78 -59
- mindspore/dataset/engine/datasets_vision.py +77 -73
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +56 -38
- mindspore/dataset/engine/validators.py +11 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +8 -8
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +3 -3
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +14 -2
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +71 -127
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/dataset/vision.h +54 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +76 -58
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +53 -66
- mindspore/mindrecord/tools/cifar10_to_mr.py +48 -63
- mindspore/mindrecord/tools/csv_to_mr.py +7 -17
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +11 -21
- mindspore/mindrecord/tools/tfrecord_to_mr.py +2 -10
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/mint/__init__.py +1137 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +512 -0
- mindspore/mint/nn/functional.py +573 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +185 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +1 -0
- mindspore/nn/cell.py +213 -257
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +109 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/layer/activation.py +84 -94
- mindspore/nn/layer/basic.py +177 -82
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +103 -45
- mindspore/nn/layer/embedding_service.py +531 -0
- mindspore/nn/layer/embedding_service_layer.py +393 -0
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +18 -9
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +63 -84
- mindspore/nn/optim/ada_grad.py +6 -4
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +7 -4
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +58 -13
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +32 -9
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +6 -6
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +61 -67
- mindspore/numpy/utils.py +3 -0
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +8 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -160
- mindspore/ops/_grad_experimental/grad_comm_ops.py +93 -36
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +92 -287
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/{cpu/concat.py → aicpu/generate_eod_mask.py} +16 -17
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +164 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +130 -58
- mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +980 -0
- mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +121 -23
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +53 -0
- mindspore/ops/extend/array_func.py +218 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +308 -0
- mindspore/ops/function/__init__.py +31 -11
- mindspore/ops/function/array_func.py +848 -1736
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +2 -5
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +27 -20
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +30 -53
- mindspore/ops/function/math_func.py +916 -2791
- mindspore/ops/function/nn_func.py +1445 -889
- mindspore/ops/function/other_func.py +6 -7
- mindspore/ops/function/parameter_func.py +6 -92
- mindspore/ops/function/random_func.py +254 -108
- mindspore/ops/function/reshard_func.py +102 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +11 -18
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +15 -14
- mindspore/ops/functional.py +342 -343
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +32 -23
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +21 -853
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +155 -511
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +3 -3
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +112 -2698
- mindspore/ops/operations/comm_ops.py +801 -118
- mindspore/ops/operations/custom_ops.py +62 -121
- mindspore/ops/operations/debug_ops.py +105 -36
- mindspore/ops/operations/image_ops.py +3 -219
- mindspore/ops/operations/inner_ops.py +54 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
- mindspore/ops/operations/math_ops.py +621 -4654
- mindspore/ops/operations/nn_ops.py +316 -2226
- mindspore/ops/operations/other_ops.py +53 -45
- mindspore/ops/operations/random_ops.py +4 -51
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +8 -8
- mindspore/ops/primitive.py +204 -103
- mindspore/ops/silent_check.py +162 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +250 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_ops.py +1084 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +968 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +354 -0
- mindspore/ops_generate/template.py +239 -0
- mindspore/parallel/__init__.py +7 -4
- mindspore/parallel/_auto_parallel_context.py +155 -6
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +62 -14
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +18 -9
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +10 -10
- mindspore/parallel/_utils.py +161 -6
- mindspore/parallel/algo_parameter_config.py +6 -8
- mindspore/parallel/checkpoint_transform.py +369 -64
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +344 -0
- mindspore/parallel/cluster/process_entity/_utils.py +126 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +128 -17
- mindspore/profiler/__init__.py +3 -2
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +125 -0
- mindspore/profiler/envprofiling.py +2 -2
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +27 -5
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +31 -280
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +151 -126
- mindspore/profiler/parser/ascend_msprof_generator.py +75 -274
- mindspore/profiler/parser/ascend_op_generator.py +94 -36
- mindspore/profiler/parser/ascend_timeline_generator.py +297 -131
- mindspore/profiler/parser/base_timeline_generator.py +17 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +11 -4
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +73 -4
- mindspore/profiler/parser/msadvisor_analyzer.py +5 -3
- mindspore/profiler/parser/msadvisor_parser.py +10 -4
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +522 -195
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +123 -37
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +46 -30
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -5
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +151 -37
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +238 -0
- mindspore/train/callback/_landscape.py +16 -11
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_mindio_ttp.py +443 -0
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +13 -14
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -21
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +298 -56
- mindspore/train/serialization.py +501 -221
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +56 -34
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
- mindspore-2.3.0.dist-info/RECORD +1400 -0
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.11.dist-info/RECORD +0 -1920
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
mindspore/rewrite/node/node.py
CHANGED
|
@@ -13,23 +13,34 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Node class define of Rewrite. See detail in Node class docstring."""
|
|
16
|
-
from typing import Optional, Union
|
|
16
|
+
from typing import Optional, Union, List, Dict
|
|
17
17
|
import ast
|
|
18
18
|
import inspect
|
|
19
19
|
from types import FunctionType
|
|
20
|
+
import sys
|
|
20
21
|
|
|
21
22
|
from mindspore.nn import Cell
|
|
22
23
|
from mindspore.ops import Primitive
|
|
23
24
|
from mindspore import log as logger
|
|
24
|
-
from ... import _checkparam as Validator
|
|
25
|
-
from ..ast_helpers import AstModifier
|
|
26
25
|
from ..api.scoped_value import ScopedValue, ValueType
|
|
27
26
|
from ..api.node_type import NodeType
|
|
28
|
-
from ..namespace import is_subtree
|
|
29
|
-
from ..
|
|
30
|
-
from ..
|
|
27
|
+
from ..common.namespace import is_subtree
|
|
28
|
+
from ..common.error_log import error_str
|
|
29
|
+
from ..ast_helpers import AstModifier, AstReplacer, AstConverter
|
|
30
|
+
from ... import _checkparam as Validator
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
if sys.version_info >= (3, 9):
|
|
34
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
35
|
+
else:
|
|
36
|
+
import astunparse
|
|
31
37
|
|
|
32
|
-
|
|
38
|
+
|
|
39
|
+
class LocalPrim(Primitive):
|
|
40
|
+
"""This class is used to indicate a local primitive instance"""
|
|
41
|
+
def __init__(self, prim_obj: type):
|
|
42
|
+
super().__init__("rewrite_local_prim")
|
|
43
|
+
self.prim_obj = prim_obj
|
|
33
44
|
|
|
34
45
|
|
|
35
46
|
class Node:
|
|
@@ -63,7 +74,7 @@ class Node:
|
|
|
63
74
|
"""
|
|
64
75
|
|
|
65
76
|
def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
|
|
66
|
-
func_name: Optional[ScopedValue], args: [ScopedValue], kwargs:
|
|
77
|
+
func_name: Optional[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue], name: str,
|
|
67
78
|
instance):
|
|
68
79
|
"""
|
|
69
80
|
Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such
|
|
@@ -77,7 +88,7 @@ class Node:
|
|
|
77
88
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
78
89
|
func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
|
|
79
90
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
80
|
-
kwargs (
|
|
91
|
+
kwargs (Dict[str, ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
81
92
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
82
93
|
Name of node also used as field name in network class.
|
|
83
94
|
instance: Object in network corresponding to this node.
|
|
@@ -90,7 +101,7 @@ class Node:
|
|
|
90
101
|
self._instance = instance
|
|
91
102
|
self._name = name
|
|
92
103
|
self._func_name: Optional[ScopedValue] = func_name
|
|
93
|
-
self._targets: [ScopedValue] = targets
|
|
104
|
+
self._targets: [ScopedValue] = targets if targets is not None else []
|
|
94
105
|
self._args_num = len(args) if args is not None else 0
|
|
95
106
|
self._kwargs_num = len(kwargs) if kwargs is not None else 0
|
|
96
107
|
self._normalized_args_keys = [] # for saving args' order
|
|
@@ -107,6 +118,10 @@ class Node:
|
|
|
107
118
|
self._arg_providers: {int: (Node, int)} = {}
|
|
108
119
|
# A dict that records which argument of which Node uses current Node's target
|
|
109
120
|
self._target_users: {int: [(Node, int)]} = {}
|
|
121
|
+
# Indicate this node represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
|
|
122
|
+
self._type_cls = None
|
|
123
|
+
# Indicate this node represent the initialize of a class type, e.g. abs_inst = P.Abs()
|
|
124
|
+
self._init_cls = None
|
|
110
125
|
|
|
111
126
|
@classmethod
|
|
112
127
|
def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
@@ -137,12 +152,6 @@ class Node:
|
|
|
137
152
|
raise RuntimeError("Input ast_node is None")
|
|
138
153
|
return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None)
|
|
139
154
|
|
|
140
|
-
@classmethod
|
|
141
|
-
def create_call_pass_through_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
142
|
-
args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = ""):
|
|
143
|
-
"""Create pass through node."""
|
|
144
|
-
return Node.create_call_method(ast_node, targets, PASS_THROUGH_METHOD, args, kwargs, name)
|
|
145
|
-
|
|
146
155
|
@classmethod
|
|
147
156
|
def create_python_node(cls, ast_node: ast.AST, name: str = "", instance=None):
|
|
148
157
|
"""
|
|
@@ -177,11 +186,11 @@ class Node:
|
|
|
177
186
|
else:
|
|
178
187
|
args = [default]
|
|
179
188
|
if ast_node is None:
|
|
180
|
-
ast_node = ast.arg(arg_name)
|
|
189
|
+
ast_node = ast.arg(arg_name, annotation="")
|
|
181
190
|
return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
|
|
182
191
|
|
|
183
192
|
@classmethod
|
|
184
|
-
def create_output_node(cls, ast_node: ast.AST,
|
|
193
|
+
def create_output_node(cls, ast_node: ast.AST, return_value: [ScopedValue], name: str = "return"):
|
|
185
194
|
"""
|
|
186
195
|
Class method of Node. Instantiate an instance of node whose type is Output. An Output node represents output of
|
|
187
196
|
SymbolTree which is corresponding to return statement of forward function.
|
|
@@ -189,17 +198,14 @@ class Node:
|
|
|
189
198
|
Args:
|
|
190
199
|
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
191
200
|
return_values (list[str]): A list of string represents name of return values.
|
|
192
|
-
name (
|
|
193
|
-
Name of node also used as field name in network class.
|
|
201
|
+
name (ScopedValue): An instance of ScopedValue represents name of node.
|
|
194
202
|
"""
|
|
195
|
-
|
|
196
|
-
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
|
|
203
|
+
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), return_value, {},
|
|
197
204
|
name, None)
|
|
198
205
|
|
|
199
206
|
@classmethod
|
|
200
207
|
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
|
|
201
|
-
op_type: ScopedValue, args: [ScopedValue],
|
|
202
|
-
ops: {str: list}, name: str = ""):
|
|
208
|
+
op_type: ScopedValue, args: [ScopedValue], name: str = ""):
|
|
203
209
|
"""
|
|
204
210
|
Class method of Node. Instantiate an instance of node whose type is `MathOps` .
|
|
205
211
|
A mathops node is used to represent a node with mathematical operations, such as
|
|
@@ -214,27 +220,11 @@ class Node:
|
|
|
214
220
|
op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
|
|
215
221
|
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
|
|
216
222
|
sequentially in the list.
|
|
217
|
-
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
|
|
218
|
-
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
|
|
219
223
|
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
220
224
|
Name of node also used as field name in network class. The format of mathops node name
|
|
221
225
|
is 'AstNodeName_AstOpName_n'.
|
|
222
226
|
"""
|
|
223
|
-
return cls(NodeType.MathOps, ast_node, targets, op_type, args,
|
|
224
|
-
|
|
225
|
-
@staticmethod
|
|
226
|
-
def create_assign_node(targets, func_name, args, kwargs):
|
|
227
|
-
"""Create a ast.Assign type node."""
|
|
228
|
-
# create targets
|
|
229
|
-
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
230
|
-
# create call
|
|
231
|
-
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
232
|
-
ast_args = ast_creator_registry.get("Args")(args)
|
|
233
|
-
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
234
|
-
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
235
|
-
# create assign
|
|
236
|
-
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
237
|
-
return ast_node
|
|
227
|
+
return cls(NodeType.MathOps, ast_node, targets, op_type, args, None, name, None)
|
|
238
228
|
|
|
239
229
|
@staticmethod
|
|
240
230
|
def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
|
|
@@ -259,34 +249,30 @@ class Node:
|
|
|
259
249
|
if kwargs is None:
|
|
260
250
|
kwargs = {}
|
|
261
251
|
targets = Node._handle_targets(targets)
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
_package = function.__globals__['__package__']
|
|
265
|
-
func_full_name = ".".join([_package, function.__name__]) if _package else function.__name__
|
|
266
|
-
func_scope = ''
|
|
267
|
-
func_name = func_full_name.split('.')[-1]
|
|
268
|
-
if func_full_name.count('.') > 0:
|
|
269
|
-
func_scope = func_full_name.rsplit('.')[0]
|
|
270
|
-
func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
|
|
252
|
+
func_name = function.__name__
|
|
253
|
+
func_scope_name = ScopedValue.create_naming_value(func_name)
|
|
271
254
|
node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs)
|
|
272
255
|
return node
|
|
273
256
|
|
|
274
257
|
@classmethod
|
|
275
|
-
def inner_create_call_function(cls, node_name, ast_node, func_name
|
|
258
|
+
def inner_create_call_function(cls, node_name: str, ast_node: ast.Assign, func_name: ScopedValue, func_obj: object,
|
|
259
|
+
targets: List[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue]):
|
|
276
260
|
'''
|
|
277
261
|
Instantiate an instance of node whose type is `CallFunction`.
|
|
278
262
|
|
|
279
263
|
Args:
|
|
280
264
|
node_name (str): Name of node.
|
|
281
|
-
func_name (
|
|
265
|
+
func_name (ScopedValue): Name of function.
|
|
282
266
|
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
args (
|
|
286
|
-
kwargs (
|
|
267
|
+
func_obj (Object): An object of function. See detail in docstring of Node class.
|
|
268
|
+
targets (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
269
|
+
args (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
270
|
+
kwargs (Dict[str, ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
287
271
|
class.
|
|
288
272
|
'''
|
|
289
|
-
|
|
273
|
+
from . import CallFunction
|
|
274
|
+
# create CallFunction node
|
|
275
|
+
return CallFunction(targets, func_name, args, kwargs, node_name, ast_node, None, None, func_obj, False)
|
|
290
276
|
|
|
291
277
|
@staticmethod
|
|
292
278
|
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
@@ -329,7 +315,7 @@ class Node:
|
|
|
329
315
|
else:
|
|
330
316
|
func_name = node_name
|
|
331
317
|
if is_sub_net and is_subtree(op):
|
|
332
|
-
from ..
|
|
318
|
+
from ..symbol_tree import SymbolTreeBuilder
|
|
333
319
|
stb = SymbolTreeBuilder(op)
|
|
334
320
|
stree = stb.build()
|
|
335
321
|
replacer = AstReplacer(stree.get_class_ast())
|
|
@@ -401,7 +387,7 @@ class Node:
|
|
|
401
387
|
elif para.kind == inspect.Parameter.VAR_KEYWORD: # corresponds to a '**kwargs'
|
|
402
388
|
var_keyword_name = name
|
|
403
389
|
else:
|
|
404
|
-
raise RuntimeError("invalid
|
|
390
|
+
raise RuntimeError("invalid parameter kind:", para.kind)
|
|
405
391
|
if "self" in position_only_names:
|
|
406
392
|
position_only_names.remove("self")
|
|
407
393
|
if "self" in positional_or_keyword_names:
|
|
@@ -528,7 +514,11 @@ class Node:
|
|
|
528
514
|
results = []
|
|
529
515
|
for target in targets:
|
|
530
516
|
if isinstance(target, str):
|
|
531
|
-
|
|
517
|
+
scope = ""
|
|
518
|
+
name = target
|
|
519
|
+
if target.count('.') > 0:
|
|
520
|
+
scope, name = target.rsplit('.', 1)
|
|
521
|
+
results.append(ScopedValue.create_naming_value(name, scope))
|
|
532
522
|
elif isinstance(target, ScopedValue):
|
|
533
523
|
results.append(target)
|
|
534
524
|
else:
|
|
@@ -556,6 +546,22 @@ class Node:
|
|
|
556
546
|
attributes["cls"] = obj.__class__
|
|
557
547
|
return attributes
|
|
558
548
|
|
|
549
|
+
def get_type_cls(self) -> object:
|
|
550
|
+
"""Get the class type object this node represented, e.g. abs_ops = _get_cache_prim(P.Abs)"""
|
|
551
|
+
return self._type_cls
|
|
552
|
+
|
|
553
|
+
def set_type_cls(self, x):
|
|
554
|
+
"""Set the class type object this node represented, e.g. abs_ops = _get_cache_prim(P.Abs)"""
|
|
555
|
+
self._type_cls = x
|
|
556
|
+
|
|
557
|
+
def get_init_cls(self) -> object:
|
|
558
|
+
"""Get the class type object initialized by this node, e.g. abs_inst = P.Abs()"""
|
|
559
|
+
return self._init_cls
|
|
560
|
+
|
|
561
|
+
def set_init_cls(self, x):
|
|
562
|
+
"""Set the class type object initialized by this node"""
|
|
563
|
+
self._init_cls = x
|
|
564
|
+
|
|
559
565
|
def get_prev(self) -> 'Node':
|
|
560
566
|
"""
|
|
561
567
|
Get previous node of current node in source code order.
|
|
@@ -683,6 +689,22 @@ class Node:
|
|
|
683
689
|
inputs.append(arg_provider[0])
|
|
684
690
|
return inputs
|
|
685
691
|
|
|
692
|
+
def get_users(self) -> ['Node']:
|
|
693
|
+
"""
|
|
694
|
+
Get user nodes of current node in topological order.
|
|
695
|
+
|
|
696
|
+
Returns:
|
|
697
|
+
A list of instances of Node as user nodes.
|
|
698
|
+
"""
|
|
699
|
+
users = []
|
|
700
|
+
for target_users in self.get_target_users().values():
|
|
701
|
+
if not target_users:
|
|
702
|
+
continue
|
|
703
|
+
for (user, _) in target_users:
|
|
704
|
+
if user not in users:
|
|
705
|
+
users.append(user)
|
|
706
|
+
return users
|
|
707
|
+
|
|
686
708
|
def get_targets(self) -> [ScopedValue]:
|
|
687
709
|
"""
|
|
688
710
|
Getter of _targets.
|
|
@@ -748,7 +770,7 @@ class Node:
|
|
|
748
770
|
"""
|
|
749
771
|
self._func_name = func_name
|
|
750
772
|
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive):
|
|
751
|
-
self.
|
|
773
|
+
self._sync_assign_func_name_to_ast()
|
|
752
774
|
|
|
753
775
|
def get_name(self) -> str:
|
|
754
776
|
"""
|
|
@@ -789,6 +811,10 @@ class Node:
|
|
|
789
811
|
Returns:
|
|
790
812
|
A type.
|
|
791
813
|
"""
|
|
814
|
+
if isinstance(self._instance, LocalPrim):
|
|
815
|
+
return self._instance.prim_obj
|
|
816
|
+
if inspect.isfunction(self._instance):
|
|
817
|
+
return self._instance
|
|
792
818
|
return type(self._instance)
|
|
793
819
|
|
|
794
820
|
def get_instance(self):
|
|
@@ -824,7 +850,7 @@ class Node:
|
|
|
824
850
|
Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
|
|
825
851
|
if out_idx is None:
|
|
826
852
|
if len(node.get_targets()) != 1:
|
|
827
|
-
raise
|
|
853
|
+
raise ValueError("node should has one output when out_idx is not provided")
|
|
828
854
|
out_idx = 0
|
|
829
855
|
Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx")
|
|
830
856
|
new_arg = node.get_targets()[out_idx]
|
|
@@ -1076,6 +1102,33 @@ class Node:
|
|
|
1076
1102
|
self.set_ast(ast_assign)
|
|
1077
1103
|
return ast_assign
|
|
1078
1104
|
|
|
1105
|
+
def get_source_code(self) -> str:
|
|
1106
|
+
"""Get source code of node from ast of node."""
|
|
1107
|
+
return astunparse.unparse(self._ast_node).strip()
|
|
1108
|
+
|
|
1109
|
+
def append_kwarg(self, kwarg: Dict[str, ScopedValue]):
|
|
1110
|
+
"""
|
|
1111
|
+
Append a new keyword arg to node.
|
|
1112
|
+
|
|
1113
|
+
Args:
|
|
1114
|
+
kwarg (Dict[str, ScopedValue]): The new keyword arg.
|
|
1115
|
+
|
|
1116
|
+
"""
|
|
1117
|
+
if self.get_node_type() not in [NodeType.Tree, NodeType.CallFunction]:
|
|
1118
|
+
raise TypeError(f"For append_new_kwarg, the type of node can only be one of [Tree, CallFunction], "
|
|
1119
|
+
f"but got {self.get_node_type()}")
|
|
1120
|
+
Validator.check_element_type_of_dict("kwarg", kwarg, [str], [ScopedValue], "append_new_kwarg")
|
|
1121
|
+
for arg_key, value in kwarg.items():
|
|
1122
|
+
# add keyword into _normalized_args
|
|
1123
|
+
self._normalized_args[arg_key] = value
|
|
1124
|
+
self._normalized_args_keys.append(arg_key)
|
|
1125
|
+
self._kwargs_num += 1
|
|
1126
|
+
# add keyword ast into ast.Call
|
|
1127
|
+
ast_assign: ast.Assign = self._ast_node
|
|
1128
|
+
ast_call: ast.Call = ast_assign.value
|
|
1129
|
+
new_keyword = ast.keyword(arg=arg_key, value=AstModifier.get_ast_by_value(value, None))
|
|
1130
|
+
ast_call.keywords.append(new_keyword)
|
|
1131
|
+
|
|
1079
1132
|
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
|
|
1080
1133
|
"""
|
|
1081
1134
|
Merge args and kwargs to normalized args.
|
|
@@ -1097,7 +1150,7 @@ class Node:
|
|
|
1097
1150
|
if not kwargs:
|
|
1098
1151
|
kwargs = {}
|
|
1099
1152
|
normalized_args: dict = dict()
|
|
1100
|
-
if self._instance and hasattr(type(self._instance), "construct"):
|
|
1153
|
+
if (args or kwargs) and self._instance and hasattr(type(self._instance), "construct"):
|
|
1101
1154
|
parameters = inspect.signature(type(self._instance).construct).parameters
|
|
1102
1155
|
names = Node._get_construct_arg_names(parameters)
|
|
1103
1156
|
Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args)
|
|
@@ -1116,12 +1169,9 @@ class Node:
|
|
|
1116
1169
|
self._normalized_args_keys.append(arg_key)
|
|
1117
1170
|
return normalized_args
|
|
1118
1171
|
|
|
1119
|
-
##########################################################################################################
|
|
1120
1172
|
# Synchronize rewrite node args to ast node
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
def _sync_assign_func_to_ast(self):
|
|
1124
|
-
"""Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
|
|
1173
|
+
def _sync_assign_func_name_to_ast(self):
|
|
1174
|
+
"""Sync func_name of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
|
|
1125
1175
|
if self._ast_node is None:
|
|
1126
1176
|
return
|
|
1127
1177
|
assign_ast = self._ast_node
|
|
@@ -1130,18 +1180,20 @@ class Node:
|
|
|
1130
1180
|
call_ast = assign_ast.value
|
|
1131
1181
|
if not isinstance(call_ast, ast.Call):
|
|
1132
1182
|
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
|
|
1183
|
+
if self._func_name.type == ValueType.UnsupportedValue:
|
|
1184
|
+
return
|
|
1133
1185
|
func_ast = call_ast.func
|
|
1134
|
-
if not self._func_name.
|
|
1186
|
+
if not self._func_name.scope:
|
|
1135
1187
|
if isinstance(func_ast, ast.Name):
|
|
1136
1188
|
func_ast.id = self._func_name.value
|
|
1137
1189
|
else:
|
|
1138
1190
|
call_ast.func = ast.Name(self._func_name.value, ast.Store())
|
|
1139
1191
|
else:
|
|
1140
1192
|
if isinstance(func_ast, ast.Attribute):
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1193
|
+
if not isinstance(func_ast.value, ast.Name):
|
|
1194
|
+
func_ast.value = ast.Name(self._func_name.scope, ast.Load())
|
|
1195
|
+
else:
|
|
1196
|
+
func_ast.value.id = self._func_name.scope
|
|
1145
1197
|
func_ast.attr = self._func_name.value
|
|
1146
1198
|
else:
|
|
1147
1199
|
call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
|
|
@@ -1154,45 +1206,42 @@ class Node:
|
|
|
1154
1206
|
return
|
|
1155
1207
|
assign_ast = self._ast_node
|
|
1156
1208
|
if not isinstance(assign_ast, ast.Assign):
|
|
1157
|
-
raise TypeError("assign_ast should be ast.Assign, got:
|
|
1209
|
+
raise TypeError(error_str(f"assign_ast should be ast.Assign, but got: {type(assign_ast)}",
|
|
1210
|
+
father_node=assign_ast))
|
|
1158
1211
|
# update targets
|
|
1159
|
-
|
|
1160
|
-
if
|
|
1161
|
-
raise
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
target_ast.id = target.value
|
|
1169
|
-
elif isinstance(target_ast, ast.Tuple):
|
|
1170
|
-
if not isinstance(target_ast.elts[i], ast.Name):
|
|
1171
|
-
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
|
|
1172
|
-
target_ast.elts[i].id = target.value
|
|
1173
|
-
else:
|
|
1174
|
-
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
|
|
1175
|
-
target_ast.id = target.value
|
|
1176
|
-
ast.fix_missing_locations(assign_ast)
|
|
1177
|
-
|
|
1178
|
-
def _sync_call_cell_args_to_ast(self):
|
|
1179
|
-
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallCell or CallPrimitive."""
|
|
1212
|
+
target_ast_elems = AstConverter.get_ast_target_elems(assign_ast.targets[0])
|
|
1213
|
+
if len(self._targets) != len(target_ast_elems):
|
|
1214
|
+
raise ValueError(error_str(f"The number of targets should be {len(target_ast_elems)}, "
|
|
1215
|
+
f"but got {len(self._targets)}", father_node=assign_ast))
|
|
1216
|
+
for i, target_ast in enumerate(target_ast_elems):
|
|
1217
|
+
target_ast_elems[i] = AstModifier.get_ast_by_value(self._targets[i], target_ast)
|
|
1218
|
+
|
|
1219
|
+
def _sync_call_args_to_ast(self):
|
|
1220
|
+
"""Sync args of ast.Call from self._normalized_args."""
|
|
1180
1221
|
if self._ast_node is None:
|
|
1181
1222
|
return
|
|
1182
1223
|
assign_ast = self._ast_node
|
|
1183
1224
|
if not isinstance(assign_ast, ast.Assign):
|
|
1184
|
-
raise TypeError(f"
|
|
1225
|
+
raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node should be "
|
|
1226
|
+
f"ast.Assign, but got: {type(assign_ast)}")
|
|
1185
1227
|
assign_value = assign_ast.value
|
|
1186
1228
|
if not isinstance(assign_value, ast.Call):
|
|
1187
|
-
|
|
1229
|
+
if isinstance(assign_value, ast.Attribute) and self._node_type in (NodeType.CellContainer,
|
|
1230
|
+
NodeType.CallCell):
|
|
1231
|
+
# CellContainers in control flow may be flatten to ast.Attribute: blocks_var = self.blocks
|
|
1232
|
+
# In this case, no args exist in node, so we don't need to sync.
|
|
1233
|
+
# CellContainers may be type of CallCell when share one implementation
|
|
1234
|
+
return
|
|
1235
|
+
raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node.value should "
|
|
1236
|
+
f"be ast.Call, but got: {type(assign_value)}")
|
|
1188
1237
|
keywords_ast = assign_value.keywords
|
|
1189
1238
|
args_ast = assign_value.args
|
|
1190
1239
|
if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
|
|
1191
|
-
raise
|
|
1192
|
-
|
|
1240
|
+
raise ValueError("ast keywords plus args len is not equal to self._normalized_args value")
|
|
1193
1241
|
for arg_index in range(self._args_num):
|
|
1194
1242
|
arg_ast = args_ast[arg_index]
|
|
1195
|
-
|
|
1243
|
+
args_ast[arg_index] = \
|
|
1244
|
+
AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)
|
|
1196
1245
|
|
|
1197
1246
|
# the order of kwargs may not the same as that in keywords_ast
|
|
1198
1247
|
keyword_map_index = {}
|
|
@@ -1200,117 +1249,61 @@ class Node:
|
|
|
1200
1249
|
keyword_map_index[keyword_ast.arg] = index
|
|
1201
1250
|
for keyword_index in range(self._kwargs_num):
|
|
1202
1251
|
key = self._normalized_args_keys[keyword_index + self._args_num]
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
def _sync_call_pass_through_method_args_to_ast(self, assign_value):
|
|
1207
|
-
"""
|
|
1208
|
-
Sync args of PASS_THROUGH_METHOD type ast.Cell of ast.Assign from self._normalized_args when NodeType is
|
|
1209
|
-
CallMethod.
|
|
1210
|
-
"""
|
|
1211
|
-
if isinstance(assign_value, ast.Name):
|
|
1212
|
-
if len(self._normalized_args_keys) != 1:
|
|
1213
|
-
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
|
1214
|
-
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
|
1215
|
-
if arg.type != ValueType.NamingValue:
|
|
1216
|
-
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
|
|
1217
|
-
if arg.scope != "":
|
|
1218
|
-
raise RuntimeError("arg.scope should be empty")
|
|
1219
|
-
assign_value.id = arg.value
|
|
1220
|
-
elif isinstance(assign_value, ast.Attribute):
|
|
1221
|
-
if len(self._normalized_args_keys) != 1:
|
|
1222
|
-
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
|
1223
|
-
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
|
1224
|
-
if arg.type != ValueType.NamingValue:
|
|
1225
|
-
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
|
|
1226
|
-
assign_value.attr = arg.value
|
|
1227
|
-
assign_value_value = assign_value.value
|
|
1228
|
-
if not isinstance(assign_value_value, ast.Name):
|
|
1229
|
-
raise RuntimeError("Only support ast.Name as value of attribute ", type(assign_value_value))
|
|
1230
|
-
assign_value_value.id = arg.scope
|
|
1231
|
-
else:
|
|
1232
|
-
if len(self._normalized_args_keys) != 1:
|
|
1233
|
-
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
|
1234
|
-
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
|
1235
|
-
if arg.type != ValueType.ConstantValue:
|
|
1236
|
-
raise RuntimeError("arg should be an ConstantValue")
|
|
1237
|
-
if arg.scope != "":
|
|
1238
|
-
raise RuntimeError("arg.scope should be empty")
|
|
1239
|
-
assign_value.value = arg.value
|
|
1252
|
+
keywords_ast[keyword_map_index.get(key)].value = \
|
|
1253
|
+
AstModifier.get_ast_by_value(self._normalized_args.get(key),
|
|
1254
|
+
keywords_ast[keyword_map_index.get(key)].value)
|
|
1240
1255
|
|
|
1241
1256
|
def _sync_call_method_args_to_ast(self):
|
|
1242
1257
|
"""
|
|
1243
1258
|
Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
|
|
1244
|
-
|
|
1245
1259
|
For node with type of CallMethod, the value of ast.Assign is one of:
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1260
|
+
| func_name | data_type | value of ast.Assign |
|
|
1261
|
+
|:---------------|:------------|:------------------------|
|
|
1262
|
+
| 'pass_through' | constants | ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str |
|
|
1263
|
+
| 'pass_through' | variables | ast.Name, ast.Attribute |
|
|
1264
|
+
| 'tuple' | tuple | ast.Tuple |
|
|
1265
|
+
| 'list' | list | ast.List |
|
|
1266
|
+
| 'dict' | dict | ast.Dict |
|
|
1250
1267
|
"""
|
|
1251
1268
|
if self._ast_node is None:
|
|
1252
1269
|
return
|
|
1253
1270
|
assign_ast = self._ast_node
|
|
1254
1271
|
if not isinstance(assign_ast, ast.Assign):
|
|
1255
|
-
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
|
1272
|
+
raise TypeError(f"For node '{self.get_name()}', assign_ast should be ast.Assign, got: ", type(assign_ast))
|
|
1256
1273
|
assign_value = assign_ast.value
|
|
1257
|
-
if self._func_name ==
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1274
|
+
if self._func_name.value == "pass_through":
|
|
1275
|
+
# update constants/variables
|
|
1276
|
+
assign_ast.value = \
|
|
1277
|
+
AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), assign_value)
|
|
1278
|
+
elif self._func_name.value in ("tuple", "list", "dict"):
|
|
1279
|
+
# update tuple/list/dict
|
|
1280
|
+
ast_elts = assign_value.values if isinstance(assign_value, ast.Dict) else assign_value.elts
|
|
1281
|
+
if len(self._normalized_args_keys) != len(ast_elts):
|
|
1282
|
+
raise ValueError(f"For node '{self.get_name()}', size of self._normalized_args_keys"
|
|
1283
|
+
f"({len(self._normalized_args_keys)}) should be equal to size of elements of "
|
|
1284
|
+
f"ast_elts({len(ast_elts)})")
|
|
1285
|
+
for index, elt in enumerate(ast_elts):
|
|
1264
1286
|
scoped_value: ScopedValue = self._normalized_args.get(self._normalized_args_keys[index])
|
|
1265
|
-
|
|
1266
|
-
elt.value = scoped_value.value
|
|
1267
|
-
elif isinstance(elt, (ast.Str, ast.Bytes)):
|
|
1268
|
-
elt.s = scoped_value.value
|
|
1269
|
-
elif isinstance(elt, ast.Num):
|
|
1270
|
-
elt.n = scoped_value.value
|
|
1271
|
-
elif isinstance(elt, ast.Name):
|
|
1272
|
-
elt.id = scoped_value.value
|
|
1273
|
-
elif isinstance(elt, ast.Attribute) and isinstance(elt.value, ast.Name):
|
|
1274
|
-
elt.value.id = scoped_value.scope
|
|
1275
|
-
elt.attr = scoped_value.value
|
|
1276
|
-
else:
|
|
1277
|
-
raise RuntimeError("Only support constant or symbol in tuple now")
|
|
1287
|
+
ast_elts[index] = AstModifier.get_ast_by_value(scoped_value, elt)
|
|
1278
1288
|
else:
|
|
1279
|
-
raise
|
|
1289
|
+
raise TypeError(f"For node '{self.get_name()}', only support (pass_through, tuple or dict method) as "
|
|
1290
|
+
f"call_method, but got {self._func_name.value}")
|
|
1280
1291
|
|
|
1281
1292
|
def _sync_return_node_to_ast(self):
|
|
1282
1293
|
"""
|
|
1283
1294
|
Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
|
|
1284
1295
|
|
|
1285
1296
|
For node with type of CallMethod, the value of ast.Assign is one of:
|
|
1286
|
-
|
|
1287
|
-
- ast.Tuple
|
|
1297
|
+
(ast.Name, ast.Attribute)
|
|
1288
1298
|
"""
|
|
1289
1299
|
if self._ast_node is None:
|
|
1290
1300
|
return
|
|
1291
1301
|
return_ast = self._ast_node
|
|
1292
1302
|
if not isinstance(return_ast, ast.Return):
|
|
1293
|
-
raise TypeError("return_ast should be ast.Return, got:
|
|
1294
|
-
# update args
|
|
1303
|
+
raise TypeError(f"For node '{self.get_name()}', return_ast should be ast.Return, got: {type(return_ast)}")
|
|
1295
1304
|
return_value_ast = return_ast.value
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
|
1299
|
-
return_value_ast.id = self._normalized_args.get(self._normalized_args_keys[0]).value
|
|
1300
|
-
elif isinstance(return_value_ast, ast.Tuple):
|
|
1301
|
-
elements = return_value_ast.elts
|
|
1302
|
-
if len(self._normalized_args.values()) != len(elements):
|
|
1303
|
-
raise RuntimeError("self._normalized_args.values() should have elements same length")
|
|
1304
|
-
for elt_index, elt in enumerate(elements):
|
|
1305
|
-
if not isinstance(elt, ast.Name):
|
|
1306
|
-
raise RuntimeError("Only support ast.Name as return value: ", elt)
|
|
1307
|
-
arg = self._normalized_args.get(self._normalized_args_keys[elt_index])
|
|
1308
|
-
if not isinstance(arg, ScopedValue):
|
|
1309
|
-
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
|
1310
|
-
elt.id = arg.value
|
|
1311
|
-
else:
|
|
1312
|
-
raise RuntimeError("Unsupported return value type: ", return_value_ast)
|
|
1313
|
-
ast.fix_missing_locations(return_ast)
|
|
1305
|
+
return_ast.value = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
|
|
1306
|
+
return_value_ast)
|
|
1314
1307
|
|
|
1315
1308
|
def _sync_mathops_node_args_to_ast(self):
|
|
1316
1309
|
"""
|
|
@@ -1324,44 +1317,67 @@ class Node:
|
|
|
1324
1317
|
if isinstance(mathops_node, ast.BinOp):
|
|
1325
1318
|
left = mathops_node.left
|
|
1326
1319
|
right = mathops_node.right
|
|
1327
|
-
AstModifier.
|
|
1328
|
-
|
|
1320
|
+
mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
|
|
1321
|
+
left)
|
|
1322
|
+
mathops_node.right = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[1]),
|
|
1323
|
+
right)
|
|
1329
1324
|
elif isinstance(mathops_node, ast.UnaryOp):
|
|
1330
1325
|
operand = mathops_node.operand
|
|
1331
|
-
|
|
1326
|
+
mathops_node.operand = \
|
|
1327
|
+
AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
|
|
1332
1328
|
elif isinstance(mathops_node, ast.BoolOp):
|
|
1333
1329
|
values = mathops_node.values
|
|
1334
1330
|
for arg_index in range(self._args_num):
|
|
1335
1331
|
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
|
1336
|
-
AstModifier.
|
|
1332
|
+
values[arg_index] = AstModifier.get_ast_by_value(arg_value, values[arg_index])
|
|
1337
1333
|
elif isinstance(mathops_node, ast.Compare):
|
|
1338
1334
|
left = mathops_node.left
|
|
1339
|
-
AstModifier.
|
|
1335
|
+
mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]),
|
|
1336
|
+
left)
|
|
1340
1337
|
comparators = mathops_node.comparators
|
|
1341
1338
|
for arg_index in range(1, self._args_num):
|
|
1342
1339
|
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
|
1343
|
-
AstModifier.
|
|
1340
|
+
comparators[arg_index - 1] = AstModifier.get_ast_by_value(arg_value, comparators[arg_index - 1])
|
|
1344
1341
|
else:
|
|
1345
1342
|
raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
|
|
1346
1343
|
"ast.BoolOp, ast.Compare), but got ", type(mathops_node))
|
|
1347
1344
|
|
|
1345
|
+
def _sync_control_flow_args_to_ast(self):
|
|
1346
|
+
"""
|
|
1347
|
+
Sync values from self._normalized_args to the ast node of control flow.
|
|
1348
|
+
"""
|
|
1349
|
+
if self._ast_node is None:
|
|
1350
|
+
return
|
|
1351
|
+
normalized_args_num = len(self._normalized_args_keys)
|
|
1352
|
+
if normalized_args_num == 0:
|
|
1353
|
+
return
|
|
1354
|
+
if normalized_args_num > 1:
|
|
1355
|
+
raise ValueError("self._normalized_args_keys should have less than 1 elements")
|
|
1356
|
+
arg_value = self._normalized_args.get(self._normalized_args_keys[0])
|
|
1357
|
+
if isinstance(self._ast_node, (ast.If, ast.IfExp, ast.While)):
|
|
1358
|
+
self._ast_node.test = AstModifier.get_ast_by_value(arg_value, self._ast_node.test)
|
|
1359
|
+
elif isinstance(self._ast_node, ast.For):
|
|
1360
|
+
self._ast_node.iter = AstModifier.get_ast_by_value(arg_value, self._ast_node.iter)
|
|
1361
|
+
else:
|
|
1362
|
+
raise ValueError(f"For Control Flow, ast_node should be one of [ast.If, ast.IfExp, "
|
|
1363
|
+
f"ast.While, ast.For], but got {type(self._ast_node)}")
|
|
1364
|
+
|
|
1348
1365
|
def _sync_arg(self):
|
|
1349
1366
|
"""Sync _normalized_args to corresponding ast node when updated."""
|
|
1350
1367
|
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \
|
|
1351
1368
|
NodeType.CellContainer, NodeType.CallFunction):
|
|
1352
|
-
self.
|
|
1369
|
+
self._sync_call_args_to_ast()
|
|
1353
1370
|
elif self._node_type == NodeType.Output:
|
|
1354
1371
|
self._sync_return_node_to_ast()
|
|
1355
1372
|
elif self._node_type == NodeType.CallMethod:
|
|
1356
1373
|
self._sync_call_method_args_to_ast()
|
|
1357
1374
|
elif self._node_type == NodeType.MathOps:
|
|
1358
1375
|
self._sync_mathops_node_args_to_ast()
|
|
1376
|
+
elif self._node_type == NodeType.ControlFlow:
|
|
1377
|
+
self._sync_control_flow_args_to_ast()
|
|
1359
1378
|
|
|
1360
1379
|
|
|
1361
|
-
##########################################################################################################
|
|
1362
1380
|
# Child classes
|
|
1363
|
-
##########################################################################################################
|
|
1364
|
-
|
|
1365
1381
|
class TreeNode(Node):
|
|
1366
1382
|
"""Tree type Node who holds a handler of SymbolTree."""
|
|
1367
1383
|
|