mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/rewrite/symbol_tree.py
CHANGED
|
@@ -21,15 +21,14 @@ import ast
|
|
|
21
21
|
import importlib
|
|
22
22
|
import types
|
|
23
23
|
import time
|
|
24
|
-
|
|
25
24
|
import astunparse
|
|
26
25
|
|
|
27
26
|
from mindspore.nn import Cell
|
|
28
27
|
from mindspore import log as logger
|
|
29
28
|
from mindspore.rewrite.ast_creator_register import ast_creator_registry
|
|
30
|
-
from .node import Node, TreeNode
|
|
29
|
+
from .node import Node, TreeNode
|
|
31
30
|
from .api.node_type import NodeType
|
|
32
|
-
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder
|
|
31
|
+
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, CheckPropertyIsUsed
|
|
33
32
|
from .api.scoped_value import ScopedValue, ValueType
|
|
34
33
|
from .symbol_tree_dumper import SymbolTreeDumper
|
|
35
34
|
from .topological_manager import TopoManager
|
|
@@ -160,7 +159,6 @@ class SymbolTree(Observer, Observable):
|
|
|
160
159
|
self._topo_mgr = TopoManager()
|
|
161
160
|
self._topo_mgr.reg_observer(self)
|
|
162
161
|
|
|
163
|
-
self._global_vars: {str, object} = {origin_network_key: origin_network}
|
|
164
162
|
self._nodes: {str, Node} = {}
|
|
165
163
|
# parameters of forward method
|
|
166
164
|
self._inputs: [Node] = []
|
|
@@ -171,6 +169,10 @@ class SymbolTree(Observer, Observable):
|
|
|
171
169
|
self._class_ast: Optional[ast.ClassDef] = None
|
|
172
170
|
self._root_ast: Optional[ast.FunctionDef] = None
|
|
173
171
|
self._init_func_ast: Optional[ast.FunctionDef] = None
|
|
172
|
+
self._deleted_field = {}
|
|
173
|
+
self._deleted_node = []
|
|
174
|
+
self._external_func_ast = []
|
|
175
|
+
self._father_class_ast = []
|
|
174
176
|
|
|
175
177
|
# head node is always point to the first node(in source code order) of SymbolTree
|
|
176
178
|
self._head = None
|
|
@@ -263,6 +265,8 @@ class SymbolTree(Observer, Observable):
|
|
|
263
265
|
for node in stree.nodes():
|
|
264
266
|
if not isinstance(node, TreeNode):
|
|
265
267
|
continue
|
|
268
|
+
if node.symbol_tree._class_ast is None:
|
|
269
|
+
continue
|
|
266
270
|
sub_stree: SymbolTree = node.symbol_tree
|
|
267
271
|
SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
|
|
268
272
|
# all modified ast.ClassDef should export to code
|
|
@@ -281,31 +285,7 @@ class SymbolTree(Observer, Observable):
|
|
|
281
285
|
"""Add Event.TopologicalChangeEvent event when build is finished."""
|
|
282
286
|
self.add_event(Event.TopologicalChangeEvent)
|
|
283
287
|
|
|
284
|
-
def
|
|
285
|
-
"""
|
|
286
|
-
Create a ast.Assign type node.
|
|
287
|
-
|
|
288
|
-
Args:
|
|
289
|
-
targets (list): _description_
|
|
290
|
-
func_name (_type_): _description_
|
|
291
|
-
args (_type_): _description_
|
|
292
|
-
kwargs (_type_): _description_
|
|
293
|
-
|
|
294
|
-
Returns:
|
|
295
|
-
_type_: _description_
|
|
296
|
-
"""
|
|
297
|
-
# create targets
|
|
298
|
-
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
299
|
-
# create call
|
|
300
|
-
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
301
|
-
ast_args = ast_creator_registry.get("Args")(args)
|
|
302
|
-
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
303
|
-
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
304
|
-
# create assign
|
|
305
|
-
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
306
|
-
return ast_node
|
|
307
|
-
|
|
308
|
-
def create_call_function(self, func, targets, args, kwargs):
|
|
288
|
+
def _create_call_function(self, func, targets, args, kwargs):
|
|
309
289
|
"""
|
|
310
290
|
Create a Node object and generate the execution code to insert into the source code.
|
|
311
291
|
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
|
@@ -345,6 +325,30 @@ class SymbolTree(Observer, Observable):
|
|
|
345
325
|
call_kwargs)
|
|
346
326
|
return node
|
|
347
327
|
|
|
328
|
+
def create_assign_node(self, targets, func_name, args, kwargs):
|
|
329
|
+
"""
|
|
330
|
+
Create a ast.Assign type node.
|
|
331
|
+
|
|
332
|
+
Args:
|
|
333
|
+
targets (list): _description_
|
|
334
|
+
func_name (_type_): _description_
|
|
335
|
+
args (_type_): _description_
|
|
336
|
+
kwargs (_type_): _description_
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
_type_: _description_
|
|
340
|
+
"""
|
|
341
|
+
# create targets
|
|
342
|
+
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
343
|
+
# create call
|
|
344
|
+
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
345
|
+
ast_args = ast_creator_registry.get("Args")(args)
|
|
346
|
+
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
347
|
+
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
348
|
+
# create assign
|
|
349
|
+
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
350
|
+
return ast_node
|
|
351
|
+
|
|
348
352
|
def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
|
|
349
353
|
'''
|
|
350
354
|
Instantiate an instance of node whose type is `CallFunction`.
|
|
@@ -458,12 +462,6 @@ class SymbolTree(Observer, Observable):
|
|
|
458
462
|
self._init_func_ast = ast_node
|
|
459
463
|
|
|
460
464
|
def get_inputs(self):
|
|
461
|
-
"""
|
|
462
|
-
Getter of `_inputs` which represents parameters of current forward method.
|
|
463
|
-
|
|
464
|
-
Returns:
|
|
465
|
-
A list of instance of Node whose node_type is NodeType.Input as input nodes.
|
|
466
|
-
"""
|
|
467
465
|
return self._inputs
|
|
468
466
|
|
|
469
467
|
def get_head_node(self):
|
|
@@ -484,17 +482,6 @@ class SymbolTree(Observer, Observable):
|
|
|
484
482
|
"""
|
|
485
483
|
return self._origin_network
|
|
486
484
|
|
|
487
|
-
def get_global_vars(self):
|
|
488
|
-
"""Get global variables."""
|
|
489
|
-
return self._global_vars
|
|
490
|
-
|
|
491
|
-
def add_global_vars(self, key: str, value):
|
|
492
|
-
"""Add global variables."""
|
|
493
|
-
if self._global_vars.get(key) is not None:
|
|
494
|
-
logger.info(f"The key '{key}' is duplicated")
|
|
495
|
-
return
|
|
496
|
-
self._global_vars[key] = value
|
|
497
|
-
|
|
498
485
|
def get_nodes_dict(self):
|
|
499
486
|
"""Get dict of nodes"""
|
|
500
487
|
return self._nodes
|
|
@@ -614,7 +601,6 @@ class SymbolTree(Observer, Observable):
|
|
|
614
601
|
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
|
615
602
|
SymbolTree.
|
|
616
603
|
"""
|
|
617
|
-
|
|
618
604
|
node = self._get_real_node(node_or_name)
|
|
619
605
|
if node is None:
|
|
620
606
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
@@ -653,7 +639,12 @@ class SymbolTree(Observer, Observable):
|
|
|
653
639
|
RuntimeError: If 'position' is not in current SymbolTree.
|
|
654
640
|
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
|
655
641
|
"""
|
|
656
|
-
|
|
642
|
+
if position is not None and hasattr(position.node, "container"):
|
|
643
|
+
cellcontainer = getattr(position.node, "container")
|
|
644
|
+
index = cellcontainer.node_list.index(position.node)
|
|
645
|
+
index = index if position.before_node else index + 1
|
|
646
|
+
cellcontainer.insert(index, node)
|
|
647
|
+
return node
|
|
657
648
|
# if position in current SymbolTree
|
|
658
649
|
if position is not None and position.symbol_tree is not self:
|
|
659
650
|
raise RuntimeError("Position is not in current SymbolTree:", position)
|
|
@@ -678,37 +669,7 @@ class SymbolTree(Observer, Observable):
|
|
|
678
669
|
self._node_visitor.append_node(node)
|
|
679
670
|
# update init-function-ast and construct-function-ast
|
|
680
671
|
if insert_to_ast:
|
|
681
|
-
|
|
682
|
-
node_ast = node.get_ast()
|
|
683
|
-
if not isinstance(node_ast, ast.Assign):
|
|
684
|
-
raise RuntimeError("Only support insert cell op now")
|
|
685
|
-
if isinstance(node, TreeNode):
|
|
686
|
-
global_vars_key = node.get_name() + "_args"
|
|
687
|
-
self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
|
|
688
|
-
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
|
|
689
|
-
[ScopedValue.create_variable_value(global_vars_key)])
|
|
690
|
-
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
691
|
-
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
692
|
-
|
|
693
|
-
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
694
|
-
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
695
|
-
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
696
|
-
|
|
697
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
698
|
-
None if position is None else position.node.get_ast(),
|
|
699
|
-
position.before_node)
|
|
700
|
-
sub_stree: SymbolTree = node.symbol_tree
|
|
701
|
-
from .symbol_tree_builder import SymbolTreeBuilder
|
|
702
|
-
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
703
|
-
else:
|
|
704
|
-
AstModifier.insert_assign_to_function(self._init_func_ast,
|
|
705
|
-
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
|
|
706
|
-
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
|
707
|
-
args=[ScopedValue(ValueType.StringValue, "", node_name)])
|
|
708
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
709
|
-
None if position is None else position.node.get_ast(),
|
|
710
|
-
position.before_node)
|
|
711
|
-
self._global_vars[node_name] = node.get_instance()
|
|
672
|
+
self._insert_to_ast_while_insert_node(node, position)
|
|
712
673
|
return node
|
|
713
674
|
|
|
714
675
|
def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
|
|
@@ -807,8 +768,9 @@ class SymbolTree(Observer, Observable):
|
|
|
807
768
|
Returns:
|
|
808
769
|
An instance of python node which has been appended to SymbolTree.
|
|
809
770
|
"""
|
|
810
|
-
logger.
|
|
771
|
+
logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
|
|
811
772
|
node_name = self._node_name_namer.get_name(type(ast_node).__name__)
|
|
773
|
+
self._update_names_for_unique(ast_node)
|
|
812
774
|
node = Node.create_python_node(ast_node, node_name)
|
|
813
775
|
self._insert_node(Position.create(self, self._tail, False), node)
|
|
814
776
|
return node
|
|
@@ -851,6 +813,10 @@ class SymbolTree(Observer, Observable):
|
|
|
851
813
|
node = self._get_real_node(node_or_name)
|
|
852
814
|
if node is None:
|
|
853
815
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
816
|
+
if hasattr(node, "container"):
|
|
817
|
+
cellcontainer = getattr(node, "container")
|
|
818
|
+
cellcontainer.erase(node)
|
|
819
|
+
return node
|
|
854
820
|
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
855
821
|
if not ret:
|
|
856
822
|
raise RuntimeError("node not in function ast tree.")
|
|
@@ -860,6 +826,7 @@ class SymbolTree(Observer, Observable):
|
|
|
860
826
|
value.isolate()
|
|
861
827
|
break
|
|
862
828
|
self._topo_mgr.on_erase_node(node)
|
|
829
|
+
self._deleted_node.append(node.get_name())
|
|
863
830
|
return node
|
|
864
831
|
|
|
865
832
|
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
|
@@ -884,6 +851,9 @@ class SymbolTree(Observer, Observable):
|
|
|
884
851
|
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
|
885
852
|
"""
|
|
886
853
|
|
|
854
|
+
if hasattr(old_node, "container"):
|
|
855
|
+
self._replace_container_node(old_node, new_nodes)
|
|
856
|
+
return new_nodes[0]
|
|
887
857
|
real_old_node = self._get_real_node(old_node)
|
|
888
858
|
if real_old_node is None:
|
|
889
859
|
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
|
@@ -981,6 +951,13 @@ class SymbolTree(Observer, Observable):
|
|
|
981
951
|
dump_st = SymbolTreeDumper(self)
|
|
982
952
|
dump_st.dump()
|
|
983
953
|
|
|
954
|
+
def update_module_ast(self):
|
|
955
|
+
for node in self._external_func_ast:
|
|
956
|
+
self._module_ast.body.append(node)
|
|
957
|
+
for node in self._father_class_ast:
|
|
958
|
+
index = self._module_ast.body.index(self._class_ast)
|
|
959
|
+
self._module_ast.body.insert(index, node)
|
|
960
|
+
|
|
984
961
|
def get_code(self) -> str:
|
|
985
962
|
"""
|
|
986
963
|
Get source code of modified network.
|
|
@@ -992,6 +969,7 @@ class SymbolTree(Observer, Observable):
|
|
|
992
969
|
if self._init_func_ast:
|
|
993
970
|
self._remove_unused_field()
|
|
994
971
|
self._remove_duplicated_import()
|
|
972
|
+
self.update_module_ast()
|
|
995
973
|
ast.fix_missing_locations(self._module_ast)
|
|
996
974
|
# Find all ast.ClassDef which can be export to code
|
|
997
975
|
# Replace duplicated ast.ClassDef reference in main-ClassDef
|
|
@@ -1026,21 +1004,20 @@ class SymbolTree(Observer, Observable):
|
|
|
1026
1004
|
A network object.
|
|
1027
1005
|
"""
|
|
1028
1006
|
cls = self._get_cls_through_file()
|
|
1029
|
-
|
|
1007
|
+
new_net = cls(self._origin_network)
|
|
1008
|
+
self._merge_origin_property(new_net)
|
|
1009
|
+
return new_net
|
|
1030
1010
|
|
|
1031
1011
|
def set_saved_file_name(self, file_name: str):
|
|
1032
|
-
"""Sets the filename used to save the network."""
|
|
1033
1012
|
if file_name.endswith(".py"):
|
|
1034
1013
|
self._saved_file_name = file_name
|
|
1035
1014
|
else:
|
|
1036
1015
|
self._saved_file_name = file_name + ".py"
|
|
1037
1016
|
|
|
1038
1017
|
def get_saved_file_name(self):
|
|
1039
|
-
"""Gets the filename used to save the network."""
|
|
1040
1018
|
return self._saved_file_name
|
|
1041
1019
|
|
|
1042
1020
|
def save_network_to_file(self):
|
|
1043
|
-
"""Save the modified network to a file."""
|
|
1044
1021
|
abs_path = os.path.abspath(self._saved_file_name)
|
|
1045
1022
|
if os.path.isfile(abs_path):
|
|
1046
1023
|
os.remove(abs_path)
|
|
@@ -1049,6 +1026,58 @@ class SymbolTree(Observer, Observable):
|
|
|
1049
1026
|
f.write(source.encode('utf-8'))
|
|
1050
1027
|
f.flush()
|
|
1051
1028
|
|
|
1029
|
+
def update_scope_for_unique(self, node: Union[ast.Attribute, ast.Call, ast.Subscript]):
|
|
1030
|
+
""" Update scope of ast node because of unique-ing of targets of other nodes. """
|
|
1031
|
+
if isinstance(node, ast.Call):
|
|
1032
|
+
self.update_scope_for_unique(node.func)
|
|
1033
|
+
return
|
|
1034
|
+
if not isinstance(node, (ast.Attribute, ast.Subscript)):
|
|
1035
|
+
logger.warning(f"Cannot update node {astunparse.unparse(node)} for unique, type of node should "
|
|
1036
|
+
f"be one of (ast.Attribute, ast.Subscript).")
|
|
1037
|
+
return
|
|
1038
|
+
scope = node.value
|
|
1039
|
+
if not isinstance(scope, ast.Name):
|
|
1040
|
+
self.update_scope_for_unique(scope)
|
|
1041
|
+
return
|
|
1042
|
+
scope_name = scope.id
|
|
1043
|
+
scope_name_unique = self._target_namer.get_real_arg(scope_name)
|
|
1044
|
+
scope.id = scope_name_unique
|
|
1045
|
+
|
|
1046
|
+
def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
|
|
1047
|
+
""" insert_to_ast_while_insert_node. """
|
|
1048
|
+
node.set_func(ScopedValue.create_naming_value(node.get_name(), "self"))
|
|
1049
|
+
node_ast = node.get_ast()
|
|
1050
|
+
if not isinstance(node_ast, ast.Assign):
|
|
1051
|
+
raise RuntimeError("Only support insert cell op now")
|
|
1052
|
+
if isinstance(node, TreeNode):
|
|
1053
|
+
setattr(self._origin_network, node.get_name(), node.get_instance())
|
|
1054
|
+
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
1055
|
+
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
1056
|
+
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
1057
|
+
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
1058
|
+
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
1059
|
+
|
|
1060
|
+
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
1061
|
+
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
1062
|
+
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
1063
|
+
|
|
1064
|
+
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1065
|
+
None if position is None else position.node.get_ast(),
|
|
1066
|
+
position.before_node)
|
|
1067
|
+
sub_stree: SymbolTree = node.symbol_tree
|
|
1068
|
+
from .symbol_tree_builder import SymbolTreeBuilder
|
|
1069
|
+
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
1070
|
+
else:
|
|
1071
|
+
AstModifier.insert_assign_to_function(self._init_func_ast,
|
|
1072
|
+
targets=[ScopedValue(ValueType.NamingValue, "self", node.get_name())],
|
|
1073
|
+
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
1074
|
+
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
1075
|
+
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
1076
|
+
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1077
|
+
None if position is None else position.node.get_ast(),
|
|
1078
|
+
position.before_node)
|
|
1079
|
+
setattr(self._origin_network, node.get_name(), node.get_instance())
|
|
1080
|
+
|
|
1052
1081
|
def _remove_unused_import(self):
|
|
1053
1082
|
"""remove unused import in self._module_ast"""
|
|
1054
1083
|
str_checker = StrChecker(self._module_ast)
|
|
@@ -1070,49 +1099,43 @@ class SymbolTree(Observer, Observable):
|
|
|
1070
1099
|
else:
|
|
1071
1100
|
body.names.remove(alias)
|
|
1072
1101
|
|
|
1102
|
+
def _replace_container_node(self, old_node, new_nodes):
|
|
1103
|
+
cellcontainer = getattr(old_node, "container")
|
|
1104
|
+
index = cellcontainer.node_list.index(old_node)
|
|
1105
|
+
for n in reversed(new_nodes):
|
|
1106
|
+
cellcontainer.insert(index, n)
|
|
1107
|
+
index = cellcontainer.node_list.index(old_node)
|
|
1108
|
+
cellcontainer.erase(old_node)
|
|
1109
|
+
|
|
1073
1110
|
def _filter_out_to_delete_field(self, to_delete_field):
|
|
1074
1111
|
"""filter out used field from `to_delete_field`"""
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
to_delete_field.pop("_handler")
|
|
1078
|
-
# filter field used in node of construct
|
|
1079
|
-
for node in self._nodes.values():
|
|
1080
|
-
if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
|
1081
|
-
func: ScopedValue = node.get_func()
|
|
1082
|
-
if func.scope == "self" and to_delete_field.get(func.value):
|
|
1083
|
-
to_delete_field.pop(func.value)
|
|
1084
|
-
if node.get_node_type() == NodeType.CallMethod and node.get_func() == PASS_THROUGH_METHOD:
|
|
1085
|
-
var_name = node.get_args()[0].value
|
|
1086
|
-
if to_delete_field.get(var_name):
|
|
1087
|
-
to_delete_field.pop(var_name)
|
|
1088
|
-
# filter field used in test-of-if of construct function
|
|
1089
|
-
for body in self._root_ast.body:
|
|
1090
|
-
if not isinstance(body, ast.If):
|
|
1112
|
+
for func_def in self._class_ast.body:
|
|
1113
|
+
if not isinstance(func_def, ast.FunctionDef):
|
|
1091
1114
|
continue
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1115
|
+
if func_def.name != "__init__":
|
|
1116
|
+
to_delete_to_delete_keys = []
|
|
1117
|
+
property_checker = CheckPropertyIsUsed(func_def)
|
|
1118
|
+
for key, _ in self._deleted_field.items():
|
|
1119
|
+
if property_checker.check("self", key):
|
|
1120
|
+
to_delete_to_delete_keys.append(key)
|
|
1121
|
+
property_checker = CheckPropertyIsUsed(func_def)
|
|
1122
|
+
for key in to_delete_to_delete_keys:
|
|
1123
|
+
self._deleted_field.pop(key)
|
|
1124
|
+
else:
|
|
1125
|
+
for body in func_def.body:
|
|
1126
|
+
if not isinstance(body, ast.If):
|
|
1127
|
+
continue
|
|
1128
|
+
test = body.test
|
|
1129
|
+
field_finder = FieldFinder(test)
|
|
1130
|
+
to_delete_to_delete_keys = []
|
|
1131
|
+
for key, _ in self._deleted_field.items():
|
|
1132
|
+
if field_finder.check(key):
|
|
1133
|
+
to_delete_to_delete_keys.append(key)
|
|
1134
|
+
for key in to_delete_to_delete_keys:
|
|
1135
|
+
self._deleted_field.pop(key)
|
|
1112
1136
|
|
|
1113
1137
|
def _remove_unused_field(self):
|
|
1114
1138
|
"""remove unused field in __init__ function"""
|
|
1115
|
-
to_delete_field = {}
|
|
1116
1139
|
multi_targets = []
|
|
1117
1140
|
for index, body in enumerate(self._init_func_ast.body):
|
|
1118
1141
|
if not isinstance(body, ast.Assign):
|
|
@@ -1121,12 +1144,12 @@ class SymbolTree(Observer, Observable):
|
|
|
1121
1144
|
for target in targets:
|
|
1122
1145
|
if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
|
|
1123
1146
|
and target.value.id == "self":
|
|
1124
|
-
|
|
1147
|
+
self._deleted_field[target.attr] = index
|
|
1125
1148
|
if len(targets) > 1:
|
|
1126
1149
|
multi_targets.append(index)
|
|
1127
|
-
self._filter_out_to_delete_field(
|
|
1150
|
+
self._filter_out_to_delete_field(self._deleted_field)
|
|
1128
1151
|
for i in range(len(self._init_func_ast.body) - 1, -1, -1):
|
|
1129
|
-
if i in
|
|
1152
|
+
if i in self._deleted_field.values():
|
|
1130
1153
|
if i in multi_targets:
|
|
1131
1154
|
raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
|
|
1132
1155
|
AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
|
|
@@ -1144,12 +1167,9 @@ class SymbolTree(Observer, Observable):
|
|
|
1144
1167
|
self._module_ast.body.remove(body)
|
|
1145
1168
|
|
|
1146
1169
|
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
|
1147
|
-
if isinstance(node_or_name, Node):
|
|
1148
|
-
result = self.get_node(node_or_name.get_name())
|
|
1149
|
-
return result if result is node_or_name else None
|
|
1150
1170
|
if isinstance(node_or_name, str):
|
|
1151
1171
|
return self.get_node(node_or_name)
|
|
1152
|
-
return
|
|
1172
|
+
return node_or_name
|
|
1153
1173
|
|
|
1154
1174
|
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
|
1155
1175
|
"""
|
|
@@ -1298,7 +1318,7 @@ class SymbolTree(Observer, Observable):
|
|
|
1298
1318
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
|
1299
1319
|
if value.type == ValueType.CustomObjValue:
|
|
1300
1320
|
field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
|
|
1301
|
-
self.
|
|
1321
|
+
setattr(self._origin_network, field, value.value)
|
|
1302
1322
|
init_targets = [ScopedValue.create_naming_value(field, "self")]
|
|
1303
1323
|
AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
|
|
1304
1324
|
result[arg] = init_targets[0]
|
|
@@ -1316,15 +1336,34 @@ class SymbolTree(Observer, Observable):
|
|
|
1316
1336
|
Returns:
|
|
1317
1337
|
A class handle.
|
|
1318
1338
|
"""
|
|
1319
|
-
|
|
1320
|
-
|
|
1339
|
+
self._update_container()
|
|
1340
|
+
file_path = os.getcwd()
|
|
1341
|
+
file_path = os.path.join(file_path, "rewritten_network")
|
|
1342
|
+
if not os.path.exists(file_path):
|
|
1343
|
+
os.mkdir(file_path)
|
|
1344
|
+
file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
|
|
1345
|
+
network_file = os.path.join(file_path, file_name)
|
|
1346
|
+
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
|
1321
1347
|
source = self.get_code()
|
|
1322
1348
|
f.write(source.encode('utf-8'))
|
|
1323
1349
|
f.flush()
|
|
1324
|
-
|
|
1350
|
+
os.fsync(f)
|
|
1351
|
+
tmp_module_path, tmp_module_file = os.path.split(network_file)
|
|
1325
1352
|
tmp_module_name = tmp_module_file[:-3]
|
|
1326
1353
|
sys.path.append(tmp_module_path)
|
|
1327
|
-
tmp_module =
|
|
1354
|
+
tmp_module = None
|
|
1355
|
+
|
|
1356
|
+
i = 0
|
|
1357
|
+
while not tmp_module:
|
|
1358
|
+
try:
|
|
1359
|
+
tmp_module = importlib.import_module(tmp_module_name)
|
|
1360
|
+
except ModuleNotFoundError:
|
|
1361
|
+
if i > 10:
|
|
1362
|
+
break
|
|
1363
|
+
time.sleep(0.1)
|
|
1364
|
+
i += 1
|
|
1365
|
+
if not tmp_module:
|
|
1366
|
+
logger.error(f"load module {tmp_module_name} failed.")
|
|
1328
1367
|
network_cls = getattr(tmp_module, self._opt_cls_name)
|
|
1329
1368
|
if network_cls is None:
|
|
1330
1369
|
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
|
@@ -1333,3 +1372,87 @@ class SymbolTree(Observer, Observable):
|
|
|
1333
1372
|
def _on_change(self, event: Event):
|
|
1334
1373
|
self._modified = True
|
|
1335
1374
|
self.changed(event)
|
|
1375
|
+
|
|
1376
|
+
def _update_container(self):
|
|
1377
|
+
"""Update instance of node in container."""
|
|
1378
|
+
for node in self.nodes():
|
|
1379
|
+
index = 0
|
|
1380
|
+
if node.get_node_type() == NodeType.CellContainer:
|
|
1381
|
+
for n in node.node_list:
|
|
1382
|
+
if not n.valid:
|
|
1383
|
+
continue
|
|
1384
|
+
if n.get_node_type() == NodeType.Tree:
|
|
1385
|
+
obj = n.symbol_tree.get_network()
|
|
1386
|
+
node.get_instance()[index] = obj
|
|
1387
|
+
else:
|
|
1388
|
+
node.get_instance()[index] = n.get_instance()
|
|
1389
|
+
index += 1
|
|
1390
|
+
|
|
1391
|
+
def _cal_difference_set(self, input, other):
|
|
1392
|
+
"""Calculate different set of two sets."""
|
|
1393
|
+
set1 = set(input)
|
|
1394
|
+
set2 = set(other)
|
|
1395
|
+
return set1 - set2
|
|
1396
|
+
|
|
1397
|
+
def _merge_origin_property(self, new_net):
|
|
1398
|
+
"""Merge property of two network."""
|
|
1399
|
+
tmp = self._cal_difference_set(dir(self._origin_network), dir(new_net))
|
|
1400
|
+
new_attr_names = self._cal_difference_set(tmp, self._deleted_field.keys())
|
|
1401
|
+
for name in new_attr_names:
|
|
1402
|
+
setattr(new_net, name, getattr(self._origin_network, name))
|
|
1403
|
+
# merger cells
|
|
1404
|
+
cells = self._cal_difference_set(self._origin_network.name_cells().keys(), new_net.name_cells().keys())
|
|
1405
|
+
cells = self._cal_difference_set(cells, self._deleted_node)
|
|
1406
|
+
for c in cells:
|
|
1407
|
+
new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
|
|
1408
|
+
# merge primitives
|
|
1409
|
+
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1410
|
+
for p in primitives:
|
|
1411
|
+
new_net._primitives[p] = self._origin_network._primitives[p]
|
|
1412
|
+
|
|
1413
|
+
def _update_names_for_unique(self, node: ast.AST):
|
|
1414
|
+
""" Update names of ast nodes for unique. """
|
|
1415
|
+
if isinstance(node, (ast.For, ast.If, ast.While)):
|
|
1416
|
+
self._update_names_for_unique_branchs(node)
|
|
1417
|
+
elif isinstance(node, ast.Assign):
|
|
1418
|
+
self._update_names_for_unique(node.value)
|
|
1419
|
+
for target in node.targets:
|
|
1420
|
+
self._update_names_for_unique(target)
|
|
1421
|
+
elif isinstance(node, ast.Call):
|
|
1422
|
+
if isinstance(node.func, ast.Attribute):
|
|
1423
|
+
self._update_names_for_unique(node.func.value)
|
|
1424
|
+
for arg in node.args:
|
|
1425
|
+
self._update_names_for_unique(arg)
|
|
1426
|
+
for keyword in node.keywords:
|
|
1427
|
+
self._update_names_for_unique(keyword)
|
|
1428
|
+
elif isinstance(node, ast.UnaryOp):
|
|
1429
|
+
self._update_names_for_unique(node.operand)
|
|
1430
|
+
elif isinstance(node, ast.BinOp):
|
|
1431
|
+
self._update_names_for_unique(node.left)
|
|
1432
|
+
self._update_names_for_unique(node.right)
|
|
1433
|
+
elif isinstance(node, (ast.Attribute, ast.Subscript, ast.Return)):
|
|
1434
|
+
self._update_names_for_unique(node.value)
|
|
1435
|
+
elif isinstance(node, (ast.List, ast.Tuple)):
|
|
1436
|
+
for elt in node.elts:
|
|
1437
|
+
self._update_names_for_unique(elt)
|
|
1438
|
+
elif isinstance(node, ast.Compare):
|
|
1439
|
+
for comparator in node.comparators:
|
|
1440
|
+
self._update_names_for_unique(comparator)
|
|
1441
|
+
elif isinstance(node, ast.Name):
|
|
1442
|
+
node.id = self._target_namer.get_real_arg(node.id)
|
|
1443
|
+
|
|
1444
|
+
def _update_names_for_unique_branchs(self, node: Union[ast.For, ast.If, ast.While]):
|
|
1445
|
+
""" Update names of ast nodes for unique with ast.For, ast.If or ast.While """
|
|
1446
|
+
if isinstance(node, ast.For):
|
|
1447
|
+
self._update_names_for_unique(node.target)
|
|
1448
|
+
self._update_names_for_unique(node.iter)
|
|
1449
|
+
for body in node.body:
|
|
1450
|
+
self._update_names_for_unique(body)
|
|
1451
|
+
for body in node.orelse:
|
|
1452
|
+
self._update_names_for_unique(body)
|
|
1453
|
+
elif isinstance(node, (ast.If, ast.While)):
|
|
1454
|
+
self._update_names_for_unique(node.test)
|
|
1455
|
+
for body in node.body:
|
|
1456
|
+
self._update_names_for_unique(body)
|
|
1457
|
+
for body in node.orelse:
|
|
1458
|
+
self._update_names_for_unique(body)
|
|
@@ -28,6 +28,41 @@ from .ast_helpers import AstModifier
|
|
|
28
28
|
from .ast_helpers import AstFinder
|
|
29
29
|
|
|
30
30
|
|
|
31
|
+
class FunctionSymbolTreeBuilder:
|
|
32
|
+
"""Create function SymbolTree"""
|
|
33
|
+
def __init__(self, network: Cell, ast_root):
|
|
34
|
+
self._origin_net = network
|
|
35
|
+
self._ast_root: ast.Module = ast_root
|
|
36
|
+
self._root_tree: Optional[SymbolTree] = None
|
|
37
|
+
|
|
38
|
+
@staticmethod
|
|
39
|
+
def _ast_transform(ast_root: ast.AST) -> ast.AST:
|
|
40
|
+
"""
|
|
41
|
+
Optimize ast before parse.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
ast_root (ast.AST): An instance of ast to be optimized.
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
An instance of ast been optimized.
|
|
48
|
+
"""
|
|
49
|
+
transform_list = [FlattenRecursiveStmt()]
|
|
50
|
+
for transformer in transform_list:
|
|
51
|
+
ast_root = transformer.transform(ast_root)
|
|
52
|
+
return ast_root
|
|
53
|
+
|
|
54
|
+
def build(self) -> SymbolTree:
|
|
55
|
+
"""
|
|
56
|
+
Build SymbolTree.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
An instance of SymbolTree.
|
|
60
|
+
"""
|
|
61
|
+
self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root)
|
|
62
|
+
self._root_tree.finish_build()
|
|
63
|
+
return self._root_tree
|
|
64
|
+
|
|
65
|
+
|
|
31
66
|
class SymbolTreeBuilder:
|
|
32
67
|
"""
|
|
33
68
|
`SymbolTreeBuilder` for building a SymbolTree from network.
|
|
@@ -43,6 +78,8 @@ class SymbolTreeBuilder:
|
|
|
43
78
|
network_str = inspect.getsource(type(network))
|
|
44
79
|
self._ast_root: ast.Module = ast.parse(network_str)
|
|
45
80
|
self._root_tree: Optional[SymbolTree] = None
|
|
81
|
+
if isinstance(network, Cell) and network.jit_config_dict:
|
|
82
|
+
self._jit_config_dict = network.jit_config_dict
|
|
46
83
|
|
|
47
84
|
@staticmethod
|
|
48
85
|
def merge_module_of_subtree(main_tree: SymbolTree, sub_stree: SymbolTree):
|
|
@@ -140,7 +177,7 @@ class SymbolTreeBuilder:
|
|
|
140
177
|
"""
|
|
141
178
|
|
|
142
179
|
for node in self._root_tree.nodes():
|
|
143
|
-
if isinstance(node, TreeNode):
|
|
180
|
+
if isinstance(node, TreeNode) and node.get_instance():
|
|
144
181
|
SymbolTreeBuilder.merge_module_of_subtree(self._root_tree, node.symbol_tree)
|
|
145
182
|
|
|
146
183
|
def _reduce_redundant_import(self):
|