mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -16,16 +16,18 @@
|
|
|
16
16
|
from typing import Optional, Union
|
|
17
17
|
import ast
|
|
18
18
|
import inspect
|
|
19
|
+
from types import FunctionType
|
|
19
20
|
|
|
20
21
|
from mindspore.nn import Cell
|
|
21
22
|
from mindspore.ops import Primitive
|
|
22
23
|
from mindspore import log as logger
|
|
23
|
-
from
|
|
24
|
-
from
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
from
|
|
28
|
-
from
|
|
24
|
+
from ... import _checkparam as Validator
|
|
25
|
+
from ..ast_helpers import AstModifier
|
|
26
|
+
from ..api.scoped_value import ScopedValue, ValueType
|
|
27
|
+
from ..api.node_type import NodeType
|
|
28
|
+
from ..namespace import is_subtree
|
|
29
|
+
from ..ast_helpers.ast_replacer import AstReplacer
|
|
30
|
+
from ..ast_creator_register import ast_creator_registry
|
|
29
31
|
|
|
30
32
|
PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough")
|
|
31
33
|
|
|
@@ -36,35 +38,33 @@ class Node:
|
|
|
36
38
|
invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
|
|
37
39
|
Node has different meaning in different type of node:
|
|
38
40
|
|
|
39
|
-
- CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
|
|
40
|
-
is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
|
|
41
|
-
corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
|
|
42
|
-
corresponding to func of call expression which means symbol of the cell-op.
|
|
41
|
+
- CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
|
|
42
|
+
`targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
|
|
43
|
+
`kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
|
|
44
|
+
method. `func` is corresponding to func of call expression which means symbol of the cell-op.
|
|
43
45
|
- CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
|
|
44
|
-
`targets`, `args`, `kwargs` and `
|
|
46
|
+
`targets`, `args`, `kwargs` and `func_name` are as previous.
|
|
45
47
|
- CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
|
|
46
|
-
`targets` is corresponding to targets of ast.Assign which means return values of this method. `
|
|
47
|
-
the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
|
|
48
|
-
value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
|
|
49
|
-
CallMethod node whose `
|
|
50
|
-
- GetAttr: retrieves a parameter from the SymbolTree hierarchy. `func` represents which parameter in SymbolTree
|
|
51
|
-
hierarchy. `targets` is corresponding to targets of ast.Assign which means what symbol to accept the result of
|
|
52
|
-
get-attr. `args` and `kwargs` are don't-care.
|
|
48
|
+
`targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
|
|
49
|
+
represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
|
|
50
|
+
method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
|
|
51
|
+
mapped to CallMethod node whose `func_name` is "PassThrough".
|
|
53
52
|
- Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
|
|
54
|
-
supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `
|
|
53
|
+
supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
|
|
55
54
|
- Input: an input node represents an input of current network which also a parameter of forward method of Cell.
|
|
56
55
|
`targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
|
|
57
|
-
of forward function. `kwargs` and `
|
|
56
|
+
of forward function. `kwargs` and `func_name` are don't-care.
|
|
58
57
|
- Output: an output node represents the output of current network which is corresponding to return statement of
|
|
59
|
-
forward method of Cell. `args` represents return values. `
|
|
60
|
-
don't-care.
|
|
58
|
+
forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
|
|
59
|
+
`kwargs` are don't-care.
|
|
61
60
|
- Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
|
|
62
|
-
`targets`, `args`, `kwargs` and `
|
|
63
|
-
instance.
|
|
61
|
+
`targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
|
|
62
|
+
SymbolTree instance.
|
|
64
63
|
"""
|
|
65
64
|
|
|
66
65
|
def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
|
|
67
|
-
|
|
66
|
+
func_name: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str,
|
|
67
|
+
instance):
|
|
68
68
|
"""
|
|
69
69
|
Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such
|
|
70
70
|
as `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
|
|
@@ -75,7 +75,7 @@ class Node:
|
|
|
75
75
|
ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should
|
|
76
76
|
not be None except when node type is Unknown.
|
|
77
77
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
78
|
-
|
|
78
|
+
func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
|
|
79
79
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
80
80
|
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
81
81
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
@@ -89,58 +89,29 @@ class Node:
|
|
|
89
89
|
self._attribute = Node._get_cell_or_prim_op_attribute(instance)
|
|
90
90
|
self._instance = instance
|
|
91
91
|
self._name = name
|
|
92
|
-
self.
|
|
92
|
+
self._func_name: Optional[ScopedValue] = func_name
|
|
93
93
|
self._targets: [ScopedValue] = targets
|
|
94
94
|
self._args_num = len(args) if args is not None else 0
|
|
95
95
|
self._kwargs_num = len(kwargs) if kwargs is not None else 0
|
|
96
96
|
self._normalized_args_keys = [] # for saving args' order
|
|
97
97
|
self._normalized_args = self._get_normalized_args(args, kwargs)
|
|
98
|
-
# edge of node
|
|
99
|
-
self._inputs: [Node] = []
|
|
100
98
|
# position in graph nodes list
|
|
101
99
|
# it will affect code-order of python code
|
|
102
100
|
self._prev: Optional[Node] = None
|
|
103
101
|
self._next: Optional[Node] = None
|
|
104
102
|
# A handler of SymbolTree current node belonging to
|
|
105
103
|
self._belong_tree = None
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
|
113
|
-
A `CallCell` node represents an invoking to cell-op.
|
|
114
|
-
A `CallPrimitive` node represents an invoking to primitive-op.
|
|
115
|
-
|
|
116
|
-
Args:
|
|
117
|
-
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
|
118
|
-
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
119
|
-
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
120
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
121
|
-
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
122
|
-
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
123
|
-
class.
|
|
124
|
-
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
125
|
-
Name of node also used as field name in network class.
|
|
126
|
-
"""
|
|
127
|
-
|
|
128
|
-
if not isinstance(op, (Cell, Primitive)):
|
|
129
|
-
raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
|
|
130
|
-
non_custom_args = Node._handle_custom_obj_in_args(args)
|
|
131
|
-
non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
|
|
132
|
-
if ast_node is None:
|
|
133
|
-
ast_node = AstModifier.create_call_assign(targets, func, non_custom_args, non_custom_kwargs)
|
|
134
|
-
if isinstance(op, Cell):
|
|
135
|
-
node_type = NodeType.CallCell
|
|
136
|
-
else:
|
|
137
|
-
node_type = NodeType.CallPrimitive
|
|
138
|
-
return cls(node_type, ast_node, targets, func, args, kwargs, name, op)
|
|
104
|
+
# A handler of NodeManager current node belonging to
|
|
105
|
+
self._node_manager = None
|
|
106
|
+
# A dict that records which target of which Node current Node's argument come from
|
|
107
|
+
self._arg_providers: {int: (Node, int)} = {}
|
|
108
|
+
# A dict that records which argument of which Node uses current Node's target
|
|
109
|
+
self._target_users: {int: [(Node, int)]} = {}
|
|
139
110
|
|
|
140
111
|
@classmethod
|
|
141
112
|
def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
142
|
-
|
|
143
|
-
name: str = ""):
|
|
113
|
+
func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
|
|
114
|
+
kwargs: {str: ScopedValue}=None, name: str = ""):
|
|
144
115
|
"""
|
|
145
116
|
Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an
|
|
146
117
|
invoking to cell-op.
|
|
@@ -149,7 +120,7 @@ class Node:
|
|
|
149
120
|
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node`
|
|
150
121
|
should not be None currently.
|
|
151
122
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
152
|
-
|
|
123
|
+
func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
153
124
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
154
125
|
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
155
126
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
@@ -159,12 +130,12 @@ class Node:
|
|
|
159
130
|
args = []
|
|
160
131
|
if kwargs is None:
|
|
161
132
|
kwargs = {}
|
|
162
|
-
if isinstance(
|
|
163
|
-
|
|
133
|
+
if isinstance(func_name, str):
|
|
134
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
164
135
|
new_targets = Node._handle_targets(targets)
|
|
165
136
|
if ast_node is None:
|
|
166
137
|
raise RuntimeError("Input ast_node is None")
|
|
167
|
-
return cls(NodeType.CallMethod, ast_node, new_targets,
|
|
138
|
+
return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None)
|
|
168
139
|
|
|
169
140
|
@classmethod
|
|
170
141
|
def create_call_pass_through_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
@@ -187,7 +158,8 @@ class Node:
|
|
|
187
158
|
return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance)
|
|
188
159
|
|
|
189
160
|
@classmethod
|
|
190
|
-
def create_input_node(cls, ast_node: ast.AST, arg_name: str, default: Optional[ScopedValue] = None,
|
|
161
|
+
def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
|
|
162
|
+
name: str = ""):
|
|
191
163
|
"""
|
|
192
164
|
Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of
|
|
193
165
|
SymbolTree which is corresponding to parameters of forward function.
|
|
@@ -204,6 +176,8 @@ class Node:
|
|
|
204
176
|
args = []
|
|
205
177
|
else:
|
|
206
178
|
args = [default]
|
|
179
|
+
if ast_node is None:
|
|
180
|
+
ast_node = ast.arg(arg_name)
|
|
207
181
|
return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
|
|
208
182
|
|
|
209
183
|
@classmethod
|
|
@@ -241,17 +215,83 @@ class Node:
|
|
|
241
215
|
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
|
|
242
216
|
sequentially in the list.
|
|
243
217
|
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
|
|
244
|
-
|
|
218
|
+
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
|
|
245
219
|
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
246
220
|
Name of node also used as field name in network class. The format of mathops node name
|
|
247
221
|
is 'AstNodeName_AstOpName_n'.
|
|
248
222
|
"""
|
|
249
223
|
return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
|
|
250
224
|
|
|
225
|
+
@staticmethod
|
|
226
|
+
def create_assign_node(targets, func_name, args, kwargs):
|
|
227
|
+
"""Create a ast.Assign type node."""
|
|
228
|
+
# create targets
|
|
229
|
+
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
230
|
+
# create call
|
|
231
|
+
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
232
|
+
ast_args = ast_creator_registry.get("Args")(args)
|
|
233
|
+
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
234
|
+
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
235
|
+
# create assign
|
|
236
|
+
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
237
|
+
return ast_node
|
|
238
|
+
|
|
239
|
+
@staticmethod
|
|
240
|
+
def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
|
|
241
|
+
kwargs: {str: ScopedValue}=None):
|
|
242
|
+
"""
|
|
243
|
+
Create a node that corresponds to a function call.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
function (FunctionType): The function to be called.
|
|
247
|
+
targets (list[str]): indicates output names. Used as targets of an assign statement in source code.
|
|
248
|
+
args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
|
249
|
+
source code. Default: ``None`` , which indicates the `function` has no args inputs.
|
|
250
|
+
kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
|
|
251
|
+
Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
|
|
252
|
+
code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
An instance of `Node`.
|
|
256
|
+
"""
|
|
257
|
+
if args is None:
|
|
258
|
+
args = []
|
|
259
|
+
if kwargs is None:
|
|
260
|
+
kwargs = {}
|
|
261
|
+
targets = Node._handle_targets(targets)
|
|
262
|
+
_package = None
|
|
263
|
+
if isinstance(function, FunctionType):
|
|
264
|
+
_package = function.__globals__['__package__']
|
|
265
|
+
func_full_name = ".".join([_package, function.__name__]) if _package else function.__name__
|
|
266
|
+
func_scope = ''
|
|
267
|
+
func_name = func_full_name.split('.')[-1]
|
|
268
|
+
if func_full_name.count('.') > 0:
|
|
269
|
+
func_scope = func_full_name.rsplit('.')[0]
|
|
270
|
+
func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
|
|
271
|
+
node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs)
|
|
272
|
+
return node
|
|
273
|
+
|
|
274
|
+
@classmethod
|
|
275
|
+
def inner_create_call_function(cls, node_name, ast_node, func_name, function, targets, args, kwargs):
|
|
276
|
+
'''
|
|
277
|
+
Instantiate an instance of node whose type is `CallFunction`.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
node_name (str): Name of node.
|
|
281
|
+
func_name (str): Name of function.
|
|
282
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
283
|
+
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
284
|
+
function (Object): An instance of function. See detail in docstring of Node class.
|
|
285
|
+
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
286
|
+
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
287
|
+
class.
|
|
288
|
+
'''
|
|
289
|
+
return cls(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, function)
|
|
290
|
+
|
|
251
291
|
@staticmethod
|
|
252
292
|
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
253
|
-
|
|
254
|
-
|
|
293
|
+
args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
|
|
294
|
+
is_sub_net: bool = False):
|
|
255
295
|
"""
|
|
256
296
|
Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
|
257
297
|
If op is custom defined, it is treated by TreeNode.
|
|
@@ -262,12 +302,11 @@ class Node:
|
|
|
262
302
|
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
|
263
303
|
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
264
304
|
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
265
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
266
305
|
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
267
306
|
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
268
307
|
class.
|
|
269
|
-
|
|
270
|
-
Name of node also used as field name in network class.
|
|
308
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
309
|
+
`SymbolTree`. Name of node also used as field name in network class.
|
|
271
310
|
is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse the
|
|
272
311
|
`cell` to a TreeNode, else a CallCell Node. Default is a False.
|
|
273
312
|
"""
|
|
@@ -275,29 +314,58 @@ class Node:
|
|
|
275
314
|
if ast_node is not None:
|
|
276
315
|
Validator.check_value_type("ast_node", ast_node, [ast.AST], "Node")
|
|
277
316
|
Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
|
|
278
|
-
Validator.check_value_type("func", func, [ScopedValue, str], "Node")
|
|
279
317
|
if args is not None:
|
|
280
318
|
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
|
|
281
319
|
if kwargs is not None:
|
|
282
320
|
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
|
|
283
|
-
cls_name = type(op).__name__
|
|
284
|
-
|
|
285
321
|
if args is None:
|
|
286
322
|
args = []
|
|
287
323
|
if kwargs is None:
|
|
288
324
|
kwargs = {}
|
|
289
|
-
|
|
290
|
-
func = ScopedValue.create_naming_value(func)
|
|
325
|
+
Validator.check_value_type("node_name", node_name, [str], "Node")
|
|
291
326
|
new_targets = Node._handle_targets(targets)
|
|
292
|
-
if
|
|
293
|
-
|
|
327
|
+
if isinstance(node_name, str):
|
|
328
|
+
func_name = ScopedValue.create_naming_value(node_name)
|
|
329
|
+
else:
|
|
330
|
+
func_name = node_name
|
|
331
|
+
if is_sub_net and is_subtree(op):
|
|
332
|
+
from ..symbol_tree_builder import SymbolTreeBuilder
|
|
294
333
|
stb = SymbolTreeBuilder(op)
|
|
295
334
|
stree = stb.build()
|
|
296
335
|
replacer = AstReplacer(stree.get_class_ast())
|
|
297
336
|
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
|
298
|
-
return TreeNode.create_tree_node(stree, ast_node, new_targets,
|
|
337
|
+
return TreeNode.create_tree_node(stree, ast_node, new_targets, func_name, args, kwargs, node_name, op)
|
|
338
|
+
|
|
339
|
+
return Node.create_call_buildin_op(op, ast_node, new_targets, func_name, args, kwargs, node_name)
|
|
340
|
+
|
|
341
|
+
@classmethod
|
|
342
|
+
def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
|
|
343
|
+
func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
|
|
344
|
+
node_name: str = ""):
|
|
345
|
+
"""
|
|
346
|
+
Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
|
347
|
+
A `CallCell` node represents an invoking to cell-op.
|
|
348
|
+
A `CallPrimitive` node represents an invoking to primitive-op.
|
|
349
|
+
|
|
350
|
+
Args:
|
|
351
|
+
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
|
352
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
353
|
+
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
354
|
+
func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
355
|
+
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
356
|
+
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
357
|
+
class.
|
|
358
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
359
|
+
`SymbolTree`. Name of node also used as field name in network class.
|
|
360
|
+
"""
|
|
299
361
|
|
|
300
|
-
|
|
362
|
+
if not isinstance(op, (Cell, Primitive)):
|
|
363
|
+
raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
|
|
364
|
+
if isinstance(op, Cell):
|
|
365
|
+
node_type = NodeType.CallCell
|
|
366
|
+
else:
|
|
367
|
+
node_type = NodeType.CallPrimitive
|
|
368
|
+
return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
|
|
301
369
|
|
|
302
370
|
@staticmethod
|
|
303
371
|
def _get_construct_arg_names(parameters):
|
|
@@ -506,21 +574,23 @@ class Node:
|
|
|
506
574
|
"""
|
|
507
575
|
return self._next
|
|
508
576
|
|
|
509
|
-
def
|
|
577
|
+
def set_prev(self, node: 'Node'):
|
|
510
578
|
"""
|
|
511
|
-
|
|
579
|
+
Set previous node of current node.
|
|
512
580
|
|
|
513
581
|
Args:
|
|
514
|
-
node (
|
|
582
|
+
node (Node): Node to be set as previous node of current node.
|
|
583
|
+
"""
|
|
584
|
+
self._prev = node
|
|
515
585
|
|
|
516
|
-
|
|
517
|
-
A bool.
|
|
586
|
+
def set_next(self, node: 'Node'):
|
|
518
587
|
"""
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
588
|
+
Set next node of current node.
|
|
589
|
+
|
|
590
|
+
Args:
|
|
591
|
+
node (Node): Node to be set as next node of current node.
|
|
592
|
+
"""
|
|
593
|
+
self._next = node
|
|
524
594
|
|
|
525
595
|
def get_ast(self) -> Optional[ast.AST]:
|
|
526
596
|
"""
|
|
@@ -550,16 +620,24 @@ class Node:
|
|
|
550
620
|
"""Set the symbol tree to which node belongs."""
|
|
551
621
|
self._belong_tree = symbol_tree
|
|
552
622
|
|
|
623
|
+
def get_node_manager(self):
|
|
624
|
+
"""Get the NodeManager current node belongs to."""
|
|
625
|
+
return self._node_manager
|
|
626
|
+
|
|
627
|
+
def set_node_manager(self, node_manager):
|
|
628
|
+
"""Set NodeManager current node belongs."""
|
|
629
|
+
self._node_manager = node_manager
|
|
630
|
+
|
|
553
631
|
def isolate(self):
|
|
554
632
|
"""Link prev node to next node and isolate node from source code order list."""
|
|
555
|
-
origin_prev: Optional[Node] = self.
|
|
556
|
-
origin_next: Optional[Node] = self.
|
|
633
|
+
origin_prev: Optional[Node] = self.get_prev()
|
|
634
|
+
origin_next: Optional[Node] = self.get_next()
|
|
557
635
|
if origin_prev is not None:
|
|
558
|
-
origin_prev.
|
|
636
|
+
origin_prev.set_next(origin_next)
|
|
559
637
|
if origin_next is not None:
|
|
560
|
-
origin_next.
|
|
561
|
-
self.
|
|
562
|
-
self.
|
|
638
|
+
origin_next.set_prev(origin_prev)
|
|
639
|
+
self.set_prev(None)
|
|
640
|
+
self.set_next(None)
|
|
563
641
|
|
|
564
642
|
def insert_before(self, node: 'Node'):
|
|
565
643
|
"""
|
|
@@ -569,12 +647,12 @@ class Node:
|
|
|
569
647
|
node (Node): An instance of node to be inserted in.
|
|
570
648
|
"""
|
|
571
649
|
node.isolate()
|
|
572
|
-
origin_prev: Optional[Node] = self.
|
|
650
|
+
origin_prev: Optional[Node] = self.get_prev()
|
|
573
651
|
if origin_prev is not None:
|
|
574
|
-
origin_prev.
|
|
575
|
-
node.
|
|
576
|
-
node.
|
|
577
|
-
self.
|
|
652
|
+
origin_prev.set_next(node)
|
|
653
|
+
node.set_prev(origin_prev)
|
|
654
|
+
node.set_next(self)
|
|
655
|
+
self.set_prev(node)
|
|
578
656
|
|
|
579
657
|
def insert_after(self, node: 'Node'):
|
|
580
658
|
"""
|
|
@@ -584,31 +662,26 @@ class Node:
|
|
|
584
662
|
node (Node): An instance of node to be inserted in.
|
|
585
663
|
"""
|
|
586
664
|
node.isolate()
|
|
587
|
-
origin_next: Optional[Node] = self.
|
|
588
|
-
self.
|
|
589
|
-
node.
|
|
590
|
-
node.
|
|
665
|
+
origin_next: Optional[Node] = self.get_next()
|
|
666
|
+
self.set_next(node)
|
|
667
|
+
node.set_prev(self)
|
|
668
|
+
node.set_next(origin_next)
|
|
591
669
|
if origin_next is not None:
|
|
592
|
-
origin_next.
|
|
670
|
+
origin_next.set_prev(node)
|
|
593
671
|
|
|
594
672
|
def get_inputs(self) -> ['Node']:
|
|
595
673
|
"""
|
|
596
|
-
|
|
674
|
+
Get input nodes of current node in topological order.
|
|
597
675
|
|
|
598
676
|
Returns:
|
|
599
677
|
A list of instances of Node as input nodes.
|
|
600
678
|
"""
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
Args:
|
|
609
|
-
inputs (list[Node]): A list of instances of Node as new input nodes.
|
|
610
|
-
"""
|
|
611
|
-
self._inputs = inputs
|
|
679
|
+
inputs = []
|
|
680
|
+
for arg_provider in self.get_arg_providers().values():
|
|
681
|
+
if not arg_provider:
|
|
682
|
+
continue
|
|
683
|
+
inputs.append(arg_provider[0])
|
|
684
|
+
return inputs
|
|
612
685
|
|
|
613
686
|
def get_targets(self) -> [ScopedValue]:
|
|
614
687
|
"""
|
|
@@ -654,26 +727,26 @@ class Node:
|
|
|
654
727
|
NodeType.MathOps):
|
|
655
728
|
self._sync_assign_targets_to_ast()
|
|
656
729
|
|
|
657
|
-
def
|
|
730
|
+
def get_func_name(self) -> ScopedValue:
|
|
658
731
|
"""
|
|
659
|
-
Getter of `
|
|
732
|
+
Getter of `_func_name`. See detail in docstring of Node class for meaning of func.
|
|
660
733
|
|
|
661
734
|
Returns:
|
|
662
735
|
An instance of ScopedValue.
|
|
663
736
|
"""
|
|
664
|
-
return self.
|
|
737
|
+
return self._func_name
|
|
665
738
|
|
|
666
|
-
def
|
|
739
|
+
def set_func_name(self, func_name: ScopedValue):
|
|
667
740
|
"""
|
|
668
|
-
Setter of `
|
|
741
|
+
Setter of `_func_name`. See detail in docstring of Node class for meaning of func.
|
|
669
742
|
|
|
670
743
|
Note:
|
|
671
|
-
When `
|
|
744
|
+
When `_func_name` is updated, corresponding ast node would be updated also.
|
|
672
745
|
|
|
673
746
|
Args:
|
|
674
747
|
func (ScopedValue): An instance of ScopedValue as new func.
|
|
675
748
|
"""
|
|
676
|
-
self.
|
|
749
|
+
self._func_name = func_name
|
|
677
750
|
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive):
|
|
678
751
|
self._sync_assign_func_to_ast()
|
|
679
752
|
|
|
@@ -750,11 +823,11 @@ class Node:
|
|
|
750
823
|
Validator.check_value_type("node", node, [Node], "Node")
|
|
751
824
|
Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
|
|
752
825
|
if out_idx is None:
|
|
753
|
-
if len(node.
|
|
826
|
+
if len(node.get_targets()) != 1:
|
|
754
827
|
raise RuntimeError("node should has one output when out_idx is not provided")
|
|
755
828
|
out_idx = 0
|
|
756
|
-
Validator.check_int_range(out_idx, 0, len(node.
|
|
757
|
-
new_arg = node.
|
|
829
|
+
Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx")
|
|
830
|
+
new_arg = node.get_targets()[out_idx]
|
|
758
831
|
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
|
|
759
832
|
self._sync_arg()
|
|
760
833
|
|
|
@@ -943,6 +1016,66 @@ class Node:
|
|
|
943
1016
|
"""
|
|
944
1017
|
return self._attribute.get(key)
|
|
945
1018
|
|
|
1019
|
+
def get_arg_providers(self) -> dict:
|
|
1020
|
+
"""
|
|
1021
|
+
Getter of _arg_providers.
|
|
1022
|
+
|
|
1023
|
+
Return:
|
|
1024
|
+
dict, key is type of int indicating the index of args, and value is type of tuple, which includes
|
|
1025
|
+
the node and the index of node's targets who provides the argument.
|
|
1026
|
+
"""
|
|
1027
|
+
return self._arg_providers
|
|
1028
|
+
|
|
1029
|
+
def set_arg_providers(self, index: int, provider: tuple):
|
|
1030
|
+
"""
|
|
1031
|
+
Setter of _arg_providers.
|
|
1032
|
+
|
|
1033
|
+
Args:
|
|
1034
|
+
index (int): Indicating provider of which argument need to be set.
|
|
1035
|
+
provider (tuple): A tuple includes the node and the index of node's targets who provides the argument.
|
|
1036
|
+
"""
|
|
1037
|
+
self._arg_providers[index] = provider
|
|
1038
|
+
|
|
1039
|
+
def get_target_users(self, index=-1) -> Union[dict, list]:
|
|
1040
|
+
"""
|
|
1041
|
+
Getter of _target_users.
|
|
1042
|
+
|
|
1043
|
+
Args:
|
|
1044
|
+
index (int): Indicating users of which target need to be got. Default: -1, means all targets's users will
|
|
1045
|
+
be returned.
|
|
1046
|
+
|
|
1047
|
+
Return:
|
|
1048
|
+
Union[dict, list]. When index is not -1, a list of users of specified target will be returned.
|
|
1049
|
+
The type of elements in list is tuple, which includes the user node and the index of node's arguments
|
|
1050
|
+
who uses the target. When index is -1, a dict will be returned. The key is index of targets, and the
|
|
1051
|
+
value is list of users of corresponding target.
|
|
1052
|
+
"""
|
|
1053
|
+
if index == -1:
|
|
1054
|
+
return self._target_users
|
|
1055
|
+
if index not in self._target_users.keys():
|
|
1056
|
+
self._target_users[index] = []
|
|
1057
|
+
return self._target_users.get(index, None)
|
|
1058
|
+
|
|
1059
|
+
def append_target_users(self, index: int, provider: tuple):
|
|
1060
|
+
"""
|
|
1061
|
+
Setter of _target_users.
|
|
1062
|
+
|
|
1063
|
+
Args:
|
|
1064
|
+
index (int): Indicating users of which target need to be append.
|
|
1065
|
+
provider (tuple): A tuple includes the node and the index of node's argument who uses the target.
|
|
1066
|
+
|
|
1067
|
+
"""
|
|
1068
|
+
if index not in self._target_users.keys():
|
|
1069
|
+
self._target_users[index] = []
|
|
1070
|
+
self._target_users.get(index).append(provider)
|
|
1071
|
+
|
|
1072
|
+
def update_ast_node(self) -> ast.AST:
|
|
1073
|
+
"""Update node's ast_node by current targets, func_name, args and kwargs."""
|
|
1074
|
+
ast_assign = AstModifier.create_call_assign(self.get_targets(), self.get_func_name(),
|
|
1075
|
+
self.get_args(), self.get_kwargs())
|
|
1076
|
+
self.set_ast(ast_assign)
|
|
1077
|
+
return ast_assign
|
|
1078
|
+
|
|
946
1079
|
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
|
|
947
1080
|
"""
|
|
948
1081
|
Merge args and kwargs to normalized args.
|
|
@@ -983,6 +1116,10 @@ class Node:
|
|
|
983
1116
|
self._normalized_args_keys.append(arg_key)
|
|
984
1117
|
return normalized_args
|
|
985
1118
|
|
|
1119
|
+
##########################################################################################################
|
|
1120
|
+
# Synchronize rewrite node args to ast node
|
|
1121
|
+
##########################################################################################################
|
|
1122
|
+
|
|
986
1123
|
def _sync_assign_func_to_ast(self):
|
|
987
1124
|
"""Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
|
|
988
1125
|
if self._ast_node is None:
|
|
@@ -994,20 +1131,21 @@ class Node:
|
|
|
994
1131
|
if not isinstance(call_ast, ast.Call):
|
|
995
1132
|
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
|
|
996
1133
|
func_ast = call_ast.func
|
|
997
|
-
if not self.
|
|
1134
|
+
if not self._func_name.value:
|
|
998
1135
|
if isinstance(func_ast, ast.Name):
|
|
999
|
-
func_ast.id = self.
|
|
1136
|
+
func_ast.id = self._func_name.value
|
|
1000
1137
|
else:
|
|
1001
|
-
call_ast.func = ast.Name(self.
|
|
1138
|
+
call_ast.func = ast.Name(self._func_name.value, ast.Store())
|
|
1002
1139
|
else:
|
|
1003
1140
|
if isinstance(func_ast, ast.Attribute):
|
|
1004
1141
|
func_value = func_ast.value
|
|
1005
1142
|
if not isinstance(func_value, ast.Name):
|
|
1006
1143
|
raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
|
|
1007
|
-
func_value.id = self.
|
|
1008
|
-
func_ast.attr = self.
|
|
1144
|
+
func_value.id = self._func_name.scope
|
|
1145
|
+
func_ast.attr = self._func_name.value
|
|
1009
1146
|
else:
|
|
1010
|
-
call_ast.func = ast.Attribute(ast.Name(self.
|
|
1147
|
+
call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
|
|
1148
|
+
self._func_name.value, ast.Store())
|
|
1011
1149
|
ast.fix_missing_locations(assign_ast)
|
|
1012
1150
|
|
|
1013
1151
|
def _sync_assign_targets_to_ast(self):
|
|
@@ -1023,7 +1161,7 @@ class Node:
|
|
|
1023
1161
|
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
|
|
1024
1162
|
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
|
|
1025
1163
|
raise RuntimeError("self._targets should have targets_ast same length")
|
|
1026
|
-
for i in
|
|
1164
|
+
for i, _ in enumerate(self._targets):
|
|
1027
1165
|
target = self._targets[i]
|
|
1028
1166
|
target_ast = targets_ast[0]
|
|
1029
1167
|
if isinstance(target_ast, ast.Name):
|
|
@@ -1043,7 +1181,7 @@ class Node:
|
|
|
1043
1181
|
return
|
|
1044
1182
|
assign_ast = self._ast_node
|
|
1045
1183
|
if not isinstance(assign_ast, ast.Assign):
|
|
1046
|
-
raise TypeError("assign_ast should be ast.Assign, got:
|
|
1184
|
+
raise TypeError(f"assign_ast should be ast.Assign, got: {type(assign_ast)}")
|
|
1047
1185
|
assign_value = assign_ast.value
|
|
1048
1186
|
if not isinstance(assign_value, ast.Call):
|
|
1049
1187
|
return
|
|
@@ -1094,23 +1232,31 @@ class Node:
|
|
|
1094
1232
|
if len(self._normalized_args_keys) != 1:
|
|
1095
1233
|
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
|
1096
1234
|
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
|
1097
|
-
if arg.type
|
|
1098
|
-
raise RuntimeError("arg should be an
|
|
1235
|
+
if arg.type != ValueType.ConstantValue:
|
|
1236
|
+
raise RuntimeError("arg should be an ConstantValue")
|
|
1099
1237
|
if arg.scope != "":
|
|
1100
1238
|
raise RuntimeError("arg.scope should be empty")
|
|
1101
1239
|
assign_value.value = arg.value
|
|
1102
1240
|
|
|
1103
1241
|
def _sync_call_method_args_to_ast(self):
|
|
1104
|
-
"""
|
|
1242
|
+
"""
|
|
1243
|
+
Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
|
|
1244
|
+
|
|
1245
|
+
For node with type of CallMethod, the value of ast.Assign is one of:
|
|
1246
|
+
- ast.Tuple
|
|
1247
|
+
- ast.Name
|
|
1248
|
+
- ast.ast.Attribute
|
|
1249
|
+
- ...
|
|
1250
|
+
"""
|
|
1105
1251
|
if self._ast_node is None:
|
|
1106
1252
|
return
|
|
1107
1253
|
assign_ast = self._ast_node
|
|
1108
1254
|
if not isinstance(assign_ast, ast.Assign):
|
|
1109
1255
|
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
|
1110
1256
|
assign_value = assign_ast.value
|
|
1111
|
-
if self.
|
|
1257
|
+
if self._func_name == PASS_THROUGH_METHOD:
|
|
1112
1258
|
self._sync_call_pass_through_method_args_to_ast(assign_value)
|
|
1113
|
-
elif self.
|
|
1259
|
+
elif self._func_name.value == "tuple":
|
|
1114
1260
|
tuple_ast: ast.Tuple = assign_value
|
|
1115
1261
|
if len(self._normalized_args_keys) != len(tuple_ast.elts):
|
|
1116
1262
|
raise RuntimeError("size of self._normalized_args_keys should be equal to size of elements of tuple")
|
|
@@ -1130,10 +1276,16 @@ class Node:
|
|
|
1130
1276
|
else:
|
|
1131
1277
|
raise RuntimeError("Only support constant or symbol in tuple now")
|
|
1132
1278
|
else:
|
|
1133
|
-
raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self.
|
|
1279
|
+
raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self._func_name.value)
|
|
1134
1280
|
|
|
1135
1281
|
def _sync_return_node_to_ast(self):
|
|
1136
|
-
"""
|
|
1282
|
+
"""
|
|
1283
|
+
Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
|
|
1284
|
+
|
|
1285
|
+
For node with type of CallMethod, the value of ast.Assign is one of:
|
|
1286
|
+
- ast.Name
|
|
1287
|
+
- ast.Tuple
|
|
1288
|
+
"""
|
|
1137
1289
|
if self._ast_node is None:
|
|
1138
1290
|
return
|
|
1139
1291
|
return_ast = self._ast_node
|
|
@@ -1195,7 +1347,7 @@ class Node:
|
|
|
1195
1347
|
|
|
1196
1348
|
def _sync_arg(self):
|
|
1197
1349
|
"""Sync _normalized_args to corresponding ast node when updated."""
|
|
1198
|
-
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree
|
|
1350
|
+
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \
|
|
1199
1351
|
NodeType.CellContainer, NodeType.CallFunction):
|
|
1200
1352
|
self._sync_call_cell_args_to_ast()
|
|
1201
1353
|
elif self._node_type == NodeType.Output:
|
|
@@ -1206,15 +1358,18 @@ class Node:
|
|
|
1206
1358
|
self._sync_mathops_node_args_to_ast()
|
|
1207
1359
|
|
|
1208
1360
|
|
|
1361
|
+
##########################################################################################################
|
|
1362
|
+
# Child classes
|
|
1363
|
+
##########################################################################################################
|
|
1364
|
+
|
|
1209
1365
|
class TreeNode(Node):
|
|
1210
1366
|
"""Tree type Node who holds a handler of SymbolTree."""
|
|
1211
1367
|
|
|
1212
1368
|
def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
|
1213
1369
|
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
|
1214
1370
|
"""
|
|
1215
|
-
Constructor of
|
|
1216
|
-
as `
|
|
1217
|
-
`create_output_node`, etc. rather than invoking constructor of Node directly.
|
|
1371
|
+
Constructor of TreeNode. Rewrite recommend to invoking class method of Node to instantiate an instance of
|
|
1372
|
+
TreeNode such as `create_tree_node` rather than invoking constructor of Node directly.
|
|
1218
1373
|
|
|
1219
1374
|
Args:
|
|
1220
1375
|
tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
|
|
@@ -1233,8 +1388,9 @@ class TreeNode(Node):
|
|
|
1233
1388
|
self.symbol_tree = tree
|
|
1234
1389
|
|
|
1235
1390
|
@classmethod
|
|
1236
|
-
def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
|
|
1237
|
-
args: [ScopedValue], kwargs: {str: ScopedValue},
|
|
1391
|
+
def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
|
|
1392
|
+
func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
|
|
1393
|
+
name: str = "", instance=None):
|
|
1238
1394
|
"""
|
|
1239
1395
|
Class method of TreeNode. Instantiate an instance of node whose type is Tree. A Tree node represents an invoking
|
|
1240
1396
|
to sub-network.
|
|
@@ -1243,104 +1399,14 @@ class TreeNode(Node):
|
|
|
1243
1399
|
tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
|
|
1244
1400
|
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
1245
1401
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1246
|
-
|
|
1402
|
+
func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
1247
1403
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1248
1404
|
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1249
1405
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
1250
1406
|
Name of node also used as field name in network class.
|
|
1251
1407
|
instance: Object in network corresponding to this node.
|
|
1252
1408
|
"""
|
|
1253
|
-
|
|
1254
|
-
non_custom_args = Node._handle_custom_obj_in_args(args)
|
|
1255
|
-
non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
|
|
1256
1409
|
new_targets = Node._handle_targets(targets)
|
|
1257
|
-
if isinstance(
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
|
|
1261
|
-
return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
class CellContainer(Node):
|
|
1265
|
-
""" Container for saving cell-objects node. """
|
|
1266
|
-
class _Visitor():
|
|
1267
|
-
""" A iterator of CellContainer nodes. """
|
|
1268
|
-
def __init__(self, cellcontainer):
|
|
1269
|
-
self._cellcontainer = cellcontainer
|
|
1270
|
-
|
|
1271
|
-
def __len__(self):
|
|
1272
|
-
""" Get the number of nodes. """
|
|
1273
|
-
return self._cellcontainer.node_count
|
|
1274
|
-
|
|
1275
|
-
def __iter__(self):
|
|
1276
|
-
"""Create an iterator over the CellContainer."""
|
|
1277
|
-
count = len(self._cellcontainer.node_list)
|
|
1278
|
-
i = 0
|
|
1279
|
-
while i < count:
|
|
1280
|
-
curr = self._cellcontainer.node_list[i]
|
|
1281
|
-
if curr.valid:
|
|
1282
|
-
yield curr
|
|
1283
|
-
i += 1
|
|
1284
|
-
|
|
1285
|
-
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
|
1286
|
-
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
|
1287
|
-
"""Constructor of CellContainer.
|
|
1288
|
-
|
|
1289
|
-
Args:
|
|
1290
|
-
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
1291
|
-
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1292
|
-
func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
1293
|
-
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1294
|
-
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1295
|
-
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
1296
|
-
Name of node also used as field name in network class.
|
|
1297
|
-
instance: Object in network corresponding to this node.
|
|
1298
|
-
"""
|
|
1299
|
-
if isinstance(func, str):
|
|
1300
|
-
func = ScopedValue.create_naming_value(func)
|
|
1301
|
-
super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
|
|
1302
|
-
self._node_list = list()
|
|
1303
|
-
self._node_count = 0
|
|
1304
|
-
|
|
1305
|
-
@property
|
|
1306
|
-
def node_count(self):
|
|
1307
|
-
"""Number of nodes."""
|
|
1308
|
-
return len(self._node_list)
|
|
1309
|
-
|
|
1310
|
-
@property
|
|
1311
|
-
def node_list(self):
|
|
1312
|
-
""" Get node list. """
|
|
1313
|
-
return self._node_list
|
|
1314
|
-
|
|
1315
|
-
def append(self, node):
|
|
1316
|
-
""" Append new node to node list. """
|
|
1317
|
-
setattr(node, "container", self)
|
|
1318
|
-
setattr(node, "valid", True)
|
|
1319
|
-
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1320
|
-
self._node_list.append(node)
|
|
1321
|
-
# when creating a cell_container, node instance is already in SequentialCell cell_list
|
|
1322
|
-
# so here we need to write a if judgement
|
|
1323
|
-
if node.get_instance() not in self.get_instance().cell_list:
|
|
1324
|
-
self.get_instance().append(node.get_instance())
|
|
1325
|
-
|
|
1326
|
-
def erase(self, node):
|
|
1327
|
-
"""Erase node form container."""
|
|
1328
|
-
index_node = self.node_list.index(node)
|
|
1329
|
-
index_instance = self.get_instance().cell_list.index(node.get_instance())
|
|
1330
|
-
if index_node != index_instance:
|
|
1331
|
-
raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
|
|
1332
|
-
setattr(node, "valid", False)
|
|
1333
|
-
del self.get_instance()[index_node]
|
|
1334
|
-
del self._node_list[index_node]
|
|
1335
|
-
|
|
1336
|
-
def insert(self, index, node):
|
|
1337
|
-
"""Insert node into container"""
|
|
1338
|
-
self.node_list.insert(index, node)
|
|
1339
|
-
setattr(node, "container", self)
|
|
1340
|
-
setattr(node, "valid", True)
|
|
1341
|
-
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1342
|
-
self.get_instance()._insert(index, node.get_instance())
|
|
1343
|
-
|
|
1344
|
-
def nodes(self):
|
|
1345
|
-
""" Return a iterator of node."""
|
|
1346
|
-
return self._Visitor(self)
|
|
1410
|
+
if isinstance(func_name, str):
|
|
1411
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
1412
|
+
return cls(tree, ast_node, new_targets, func_name, args, kwargs, name, instance)
|