mindspore 2.0.0a0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the registry's release page for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/rewrite/__init__.py
CHANGED
|
@@ -33,6 +33,8 @@ from .api.node import Node
|
|
|
33
33
|
from .api.node_type import NodeType
|
|
34
34
|
from .api.pattern_engine import PatternEngine, PatternNode, VarNode, Replacement
|
|
35
35
|
from .api.tree_node_helper import TreeNodeHelper
|
|
36
|
+
from .sparsify.sparsify import sparsify
|
|
37
|
+
from .sparsify.utils import ArgType, SparseFunc
|
|
36
38
|
|
|
37
39
|
__all__ = ["SymbolTree", "Node", "NodeType", "ScopedValue", "ValueType", "PatternEngine", "PatternNode", "VarNode",
|
|
38
|
-
"Replacement", "TreeNodeHelper"]
|
|
40
|
+
"Replacement", "TreeNodeHelper", "sparsify", "ArgType", "SparseFunc"]
|
mindspore/rewrite/api/node.py
CHANGED
|
@@ -18,7 +18,7 @@ from typing import Union, Optional
|
|
|
18
18
|
|
|
19
19
|
from mindspore.nn import Cell
|
|
20
20
|
from mindspore.ops.primitive import Primitive
|
|
21
|
-
from
|
|
21
|
+
from mindspore import _checkparam as Validator
|
|
22
22
|
from ..node import Node as NodeImpl
|
|
23
23
|
from ..symbol_tree import SymbolTree as SymbolTreeImpl
|
|
24
24
|
from .node_type import NodeType
|
|
@@ -99,12 +99,6 @@ class Node:
|
|
|
99
99
|
args, kwargs, name, is_sub_net))
|
|
100
100
|
|
|
101
101
|
def get_handler(self) -> NodeImpl:
|
|
102
|
-
"""
|
|
103
|
-
Get handler of node implementation.
|
|
104
|
-
|
|
105
|
-
Returns:
|
|
106
|
-
An instance of `NodeImpl`.
|
|
107
|
-
"""
|
|
108
102
|
return self._node
|
|
109
103
|
|
|
110
104
|
def get_inputs(self) -> ['Node']:
|
|
@@ -181,7 +175,6 @@ class Node:
|
|
|
181
175
|
|
|
182
176
|
Raises:
|
|
183
177
|
RuntimeError: If `src_node` is not belong to current `SymbolTree`.
|
|
184
|
-
RuntimeError: If current node and `src_node` is not belong to same `SymbolTree`.
|
|
185
178
|
TypeError: If `arg_idx` is not a `int` number.
|
|
186
179
|
ValueError: If `arg_idx` is out of range.
|
|
187
180
|
TypeError: If `src_node` is not a `Node` instance.
|
|
@@ -209,27 +202,6 @@ class Node:
|
|
|
209
202
|
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node.get_handler(), out_idx)
|
|
210
203
|
|
|
211
204
|
def get_targets(self) -> [ScopedValue]:
|
|
212
|
-
"""
|
|
213
|
-
Get targets of current node.
|
|
214
|
-
|
|
215
|
-
- When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings
|
|
216
|
-
represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
|
|
217
|
-
ast.Assign.
|
|
218
|
-
- When node_type of current node is Input, `targets` should have only one element which is a string represents
|
|
219
|
-
parameter of function.
|
|
220
|
-
- When node_type of current node is `Python` or `Output`, `targets` are don't-care.
|
|
221
|
-
|
|
222
|
-
Returns:
|
|
223
|
-
A list of instances of ScopedValue as targets of node.
|
|
224
|
-
|
|
225
|
-
Examples:
|
|
226
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
227
|
-
>>> from lenet import Lenet
|
|
228
|
-
>>> net = Lenet()
|
|
229
|
-
>>> stree = SymbolTree.create(net)
|
|
230
|
-
>>> node = stree.get_node("conv1")
|
|
231
|
-
>>> targets = node.get_targets()
|
|
232
|
-
"""
|
|
233
205
|
return self._node.get_targets()
|
|
234
206
|
|
|
235
207
|
def get_name(self) -> str:
|
|
@@ -284,106 +256,21 @@ class Node:
|
|
|
284
256
|
return self._node.get_instance_type()
|
|
285
257
|
|
|
286
258
|
def get_instance(self):
|
|
287
|
-
"""
|
|
288
|
-
Get the instance of current node.
|
|
289
|
-
|
|
290
|
-
- When node_type of current node is `CallCell`, instance is an instance of Cell.
|
|
291
|
-
- When node_type of current node is `CallPrimitive`, instance is an instance of primitive.
|
|
292
|
-
- When node_type of current node is `Tree`, instance is an instance of network-cell.
|
|
293
|
-
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None.
|
|
294
|
-
|
|
295
|
-
Returns:
|
|
296
|
-
A object represents corresponding instance of current node.
|
|
297
|
-
"""
|
|
298
259
|
return self._node.get_instance()
|
|
299
260
|
|
|
300
261
|
def get_args(self) -> [ScopedValue]:
|
|
301
|
-
"""
|
|
302
|
-
Get the arguments of current node.
|
|
303
|
-
|
|
304
|
-
- When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
|
|
305
|
-
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
|
306
|
-
- When `node_type` of current node is `Input`, arguments represents default-value of argument of function.
|
|
307
|
-
- When `node_type` of current node is `Output`, arguments represents the return values of network.
|
|
308
|
-
- When `node_type` of current node is `Python`, arguments are don't-care.
|
|
309
|
-
|
|
310
|
-
Returns:
|
|
311
|
-
A list of instances of `ScopedValue`.
|
|
312
|
-
|
|
313
|
-
Examples:
|
|
314
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
315
|
-
>>> from lenet import Lenet
|
|
316
|
-
>>> net = Lenet()
|
|
317
|
-
>>> stree = SymbolTree.create(net)
|
|
318
|
-
>>> node = stree.get_node("conv1")
|
|
319
|
-
>>> args = node.get_args()
|
|
320
|
-
"""
|
|
321
262
|
return self._node.get_args()
|
|
322
263
|
|
|
323
264
|
def get_kwargs(self) -> {str: ScopedValue}:
|
|
324
|
-
"""
|
|
325
|
-
Get the keyword arguments of current node.
|
|
326
|
-
|
|
327
|
-
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding
|
|
328
|
-
to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
|
329
|
-
- When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care.
|
|
330
|
-
|
|
331
|
-
Returns:
|
|
332
|
-
A dict of str to instance of `ScopedValue`.
|
|
333
|
-
|
|
334
|
-
Examples:
|
|
335
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
336
|
-
>>> from lenet import Lenet
|
|
337
|
-
>>> net = Lenet()
|
|
338
|
-
>>> stree = SymbolTree.create(net)
|
|
339
|
-
>>> node = stree.get_node("conv1")
|
|
340
|
-
>>> kwargs = node.get_kwargs()
|
|
341
|
-
"""
|
|
342
265
|
return self._node.get_kwargs()
|
|
343
266
|
|
|
344
267
|
def set_attribute(self, key: str, value):
|
|
345
|
-
"""
|
|
346
|
-
Set attribute of current node.
|
|
347
|
-
|
|
348
|
-
Args:
|
|
349
|
-
key (str): Key of attribute.
|
|
350
|
-
value (object): Value of attribute.
|
|
351
|
-
|
|
352
|
-
Raises:
|
|
353
|
-
TypeError: If `key` is not a `str`.
|
|
354
|
-
|
|
355
|
-
Examples:
|
|
356
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
357
|
-
>>> from lenet import Lenet
|
|
358
|
-
>>> net = Lenet()
|
|
359
|
-
>>> stree = SymbolTree.create(net)
|
|
360
|
-
>>> node = stree.get_node("conv1")
|
|
361
|
-
>>> node.set_attribute("channel", 3)
|
|
362
|
-
"""
|
|
363
268
|
Validator.check_value_type("key", key, [str], "Node attribute")
|
|
364
269
|
self._node.set_attribute(key, value)
|
|
365
270
|
|
|
366
271
|
def get_attributes(self) -> {str: object}:
|
|
367
|
-
"""
|
|
368
|
-
Get all attributes of current node.
|
|
369
|
-
|
|
370
|
-
Returns:
|
|
371
|
-
A dict of str to instance of object as attributes.
|
|
372
|
-
"""
|
|
373
272
|
return self._node.get_attributes()
|
|
374
273
|
|
|
375
274
|
def get_attribute(self, key: str):
|
|
376
|
-
"""
|
|
377
|
-
Get attribute of current node by key.
|
|
378
|
-
|
|
379
|
-
Args:
|
|
380
|
-
key (str): Key of attribute.
|
|
381
|
-
|
|
382
|
-
Returns:
|
|
383
|
-
A object as attribute, can be any type.
|
|
384
|
-
|
|
385
|
-
Raises:
|
|
386
|
-
TypeError: If `key` is not a `str`.
|
|
387
|
-
"""
|
|
388
275
|
Validator.check_value_type("key", key, [str], "Node attribute")
|
|
389
276
|
return self._node.get_attribute(key)
|
|
@@ -29,6 +29,7 @@ class NodeType(Enum):
|
|
|
29
29
|
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
|
|
30
30
|
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
|
|
31
31
|
- Tree: `Tree` node represents sub-network invoking in forward method.
|
|
32
|
+
- MathOps: `MathOps` node represents a mathematical operation, such as adding or comparing in forward method.
|
|
32
33
|
|
|
33
34
|
"""
|
|
34
35
|
Unknown = 0
|
|
@@ -43,3 +44,5 @@ class NodeType(Enum):
|
|
|
43
44
|
Input = 7
|
|
44
45
|
Output = 8
|
|
45
46
|
Tree = 9
|
|
47
|
+
CellContainer = 10
|
|
48
|
+
MathOps = 11
|
|
@@ -20,7 +20,7 @@ import abc
|
|
|
20
20
|
from mindspore.nn import Cell
|
|
21
21
|
from mindspore.ops.primitive import Primitive
|
|
22
22
|
from mindspore import log as logger
|
|
23
|
-
from
|
|
23
|
+
from mindspore import _checkparam as Validator
|
|
24
24
|
from .node_type import NodeType
|
|
25
25
|
from .node import Node
|
|
26
26
|
from .symbol_tree import SymbolTree
|
|
@@ -308,6 +308,16 @@ class PatternEngine:
|
|
|
308
308
|
queue.extend(inputs_dict.get(cur_node.get_name()))
|
|
309
309
|
return new_root
|
|
310
310
|
|
|
311
|
+
@staticmethod
|
|
312
|
+
def _multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes):
|
|
313
|
+
"""Replace node in CellContainer."""
|
|
314
|
+
to_erase_list = list(matched_dict.values())
|
|
315
|
+
stree.replace(Node(node), new_nodes)
|
|
316
|
+
for n in reversed(to_erase_list):
|
|
317
|
+
if n.get_handler() is node:
|
|
318
|
+
continue
|
|
319
|
+
stree.erase_node(n)
|
|
320
|
+
|
|
311
321
|
def apply(self, stree: SymbolTree) -> bool:
|
|
312
322
|
"""
|
|
313
323
|
Apply current pattern to a `SymbolTree`.
|
|
@@ -359,6 +369,9 @@ class PatternEngine:
|
|
|
359
369
|
visited.append(cur_node)
|
|
360
370
|
queue.extend(cur_node.get_users())
|
|
361
371
|
continue
|
|
372
|
+
if cur_node.get_node_type() == NodeType.CellContainer:
|
|
373
|
+
self._process_cellcontainer(stree, cur_node.get_handler())
|
|
374
|
+
continue
|
|
362
375
|
visited.append(cur_node)
|
|
363
376
|
matched, matched_dict = self._match(self._pattern, cur_node)
|
|
364
377
|
# not matched
|
|
@@ -460,3 +473,20 @@ class PatternEngine:
|
|
|
460
473
|
logger.debug("Check match failed, pattern leaked")
|
|
461
474
|
return False
|
|
462
475
|
return True
|
|
476
|
+
|
|
477
|
+
def _process_cellcontainer(self, stree, cellcontainer):
|
|
478
|
+
"""Process CellContainer node."""
|
|
479
|
+
for node in cellcontainer.nodes():
|
|
480
|
+
if node.get_node_type() == NodeType.Tree:
|
|
481
|
+
subtree = node.symbol_tree
|
|
482
|
+
self.apply(SymbolTree(subtree))
|
|
483
|
+
continue
|
|
484
|
+
matched, matched_dict = self._match(self._pattern, Node(node))
|
|
485
|
+
if not matched:
|
|
486
|
+
continue
|
|
487
|
+
new_nodes = []
|
|
488
|
+
if self._replacement is not None:
|
|
489
|
+
new_nodes = self._replacement(self._pattern, self._is_chain, matched_dict)
|
|
490
|
+
if not new_nodes: # if replacement is empty, do nothing
|
|
491
|
+
continue
|
|
492
|
+
PatternEngine._multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Rewrite module api: ValueType and ScopedValue."""
|
|
16
16
|
from enum import Enum
|
|
17
17
|
from typing import Optional, Union
|
|
18
|
-
from
|
|
18
|
+
from mindspore import _checkparam as Validator
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class ValueType(Enum):
|
|
@@ -127,8 +127,8 @@ class ScopedValue:
|
|
|
127
127
|
Create a list of naming `ScopedValue`.
|
|
128
128
|
|
|
129
129
|
Args:
|
|
130
|
-
names
|
|
131
|
-
scopes
|
|
130
|
+
names (list[str] or tuple[str]): List or tuple of `str` represents names of referenced variables.
|
|
131
|
+
scopes (list[str] or tuple[str]): List or tuple of `str` represents scopes of referenced variables.
|
|
132
132
|
|
|
133
133
|
Returns:
|
|
134
134
|
An list of instance of `ScopedValue`.
|
|
@@ -140,7 +140,7 @@ class ScopedValue:
|
|
|
140
140
|
|
|
141
141
|
Examples:
|
|
142
142
|
>>> from mindspore.rewrite import ScopedValue
|
|
143
|
-
>>> variables = ScopedValue.create_name_values(["z", "z_1"]
|
|
143
|
+
>>> variables = ScopedValue.create_name_values(["z", "z_1"], name="subnet")
|
|
144
144
|
"""
|
|
145
145
|
Validator.check_element_type_of_iterable("names", names, [str], "ScopedValue")
|
|
146
146
|
if scopes is not None:
|
|
@@ -18,7 +18,7 @@ from types import FunctionType
|
|
|
18
18
|
import mindspore as ms
|
|
19
19
|
|
|
20
20
|
from mindspore.nn import Cell
|
|
21
|
-
from
|
|
21
|
+
from mindspore import _checkparam as Validator
|
|
22
22
|
from .node import Node
|
|
23
23
|
from ..symbol_tree_builder import SymbolTreeBuilder
|
|
24
24
|
from ..symbol_tree import Position, SymbolTree as SymbolTreeImpl
|
|
@@ -70,40 +70,7 @@ class SymbolTree:
|
|
|
70
70
|
if v not in MsDtypes and not isinstance(v, ParamTypes):
|
|
71
71
|
raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
|
|
72
72
|
|
|
73
|
-
def create_call_function(self, func, targets, *args, **kwargs):
|
|
74
|
-
r"""
|
|
75
|
-
Create a Node object and generate the execution code to insert into the source code.
|
|
76
|
-
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
|
77
|
-
|
|
78
|
-
Args:
|
|
79
|
-
func (FunctionType): The function to be called.
|
|
80
|
-
targets (list[str]): indicates the output name. As the output of the node in the source code.
|
|
81
|
-
args (Union[MsDtypes, ParamTypes]): parameter name of the node. Used as a parameter to a code statement in
|
|
82
|
-
source code. The default value is None, which means there is no parameter input in the cell.
|
|
83
|
-
kwargs (dict{str,Union[MsDtypes, ParamTypes]}): The key type must be str,
|
|
84
|
-
and the value must be value or type must be ParamTypes.
|
|
85
|
-
The input parameter name used to describe the formal parameter with a keyword.
|
|
86
|
-
Enter the name in the source code as the 'kwargs' in the statement expression.The default value is
|
|
87
|
-
None, which means there is no 'kwargs' input.
|
|
88
|
-
|
|
89
|
-
Returns:
|
|
90
|
-
An instance of `Node`.
|
|
91
|
-
|
|
92
|
-
Raises:
|
|
93
|
-
TypeError: If `func` is not FunctionType.
|
|
94
|
-
TypeError: If `targets` is not `list`.
|
|
95
|
-
TypeError: If the type of `targets` is not str.
|
|
96
|
-
TypeError: If arg in `args` is not ParamType.
|
|
97
|
-
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.
|
|
98
|
-
|
|
99
|
-
Examples:
|
|
100
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
101
|
-
>>> from lenet import Lenet
|
|
102
|
-
>>> net = Lenet()
|
|
103
|
-
>>> stree = SymbolTree.create(net)
|
|
104
|
-
>>> node = stree.get_node("conv1")
|
|
105
|
-
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
|
|
106
|
-
"""
|
|
73
|
+
def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
|
|
107
74
|
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
|
|
108
75
|
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
|
|
109
76
|
args_ = list(args)
|
|
@@ -115,22 +82,9 @@ class SymbolTree:
|
|
|
115
82
|
for key, value in kwargs.items():
|
|
116
83
|
if isinstance(value, Node):
|
|
117
84
|
kwargs[key] = value.get_handler()
|
|
118
|
-
return Node(self._symbol_tree.
|
|
85
|
+
return Node(self._symbol_tree._create_call_function(func, targets, args_, kwargs)) # pylint: disable=W0212
|
|
119
86
|
|
|
120
87
|
def get_handler(self) -> SymbolTreeImpl:
|
|
121
|
-
"""
|
|
122
|
-
Get handler of `SymbolTree` implementation.
|
|
123
|
-
|
|
124
|
-
Returns:
|
|
125
|
-
An instance of `SymbolTree`.
|
|
126
|
-
|
|
127
|
-
Examples:
|
|
128
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
129
|
-
>>> from lenet import Lenet
|
|
130
|
-
>>> net = Lenet()
|
|
131
|
-
>>> stree = SymbolTree.create(net)
|
|
132
|
-
>>> handler = stree.get_handler()
|
|
133
|
-
"""
|
|
134
88
|
return self._symbol_tree
|
|
135
89
|
|
|
136
90
|
def nodes(self):
|
|
@@ -152,25 +106,6 @@ class SymbolTree:
|
|
|
152
106
|
yield Node(node)
|
|
153
107
|
|
|
154
108
|
def get_node(self, node_name: str) -> Optional[Node]:
|
|
155
|
-
"""
|
|
156
|
-
Get node by `node_name`.
|
|
157
|
-
|
|
158
|
-
Args:
|
|
159
|
-
node_name (str): A string represents name of node.
|
|
160
|
-
|
|
161
|
-
Returns:
|
|
162
|
-
An instance of node if find else None.
|
|
163
|
-
|
|
164
|
-
Raises:
|
|
165
|
-
TypeError: If `node_name` is not `str`.
|
|
166
|
-
|
|
167
|
-
Examples:
|
|
168
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
169
|
-
>>> from lenet import Lenet
|
|
170
|
-
>>> net = Lenet()
|
|
171
|
-
>>> stree = SymbolTree.create(net)
|
|
172
|
-
>>> node = stree.get_node("conv1")
|
|
173
|
-
"""
|
|
174
109
|
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
|
|
175
110
|
node_impl = self._symbol_tree.get_node(node_name)
|
|
176
111
|
if node_impl is None:
|
|
@@ -354,16 +289,6 @@ class SymbolTree:
|
|
|
354
289
|
self._symbol_tree.dump()
|
|
355
290
|
|
|
356
291
|
def print_node_tabulate(self):
|
|
357
|
-
"""
|
|
358
|
-
Print node information of graph.
|
|
359
|
-
|
|
360
|
-
Examples:
|
|
361
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
362
|
-
>>> from lenet import Lenet
|
|
363
|
-
>>> net = Lenet()
|
|
364
|
-
>>> stree = SymbolTree.create(net)
|
|
365
|
-
>>> stree.print_node_tabulate()
|
|
366
|
-
"""
|
|
367
292
|
self._symbol_tree.print_node_tabulate()
|
|
368
293
|
|
|
369
294
|
def get_code(self) -> str:
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
from typing import Optional
|
|
17
17
|
|
|
18
18
|
from mindspore import log as logger
|
|
19
|
-
from
|
|
19
|
+
from mindspore import _checkparam as Validator
|
|
20
20
|
from .symbol_tree import SymbolTree
|
|
21
21
|
from .node import Node
|
|
22
22
|
from .node_type import NodeType
|
|
@@ -17,11 +17,11 @@
|
|
|
17
17
|
Define some ast helpers for manipulating python ast.
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
from .ast_finder import AstFinder, StrChecker,
|
|
20
|
+
from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj
|
|
21
21
|
from .ast_replacer import AstReplacer
|
|
22
22
|
from .ast_modifier import AstModifier
|
|
23
23
|
from .ast_creator import ast_args_creator, ast_assign_creator, ast_attributer_creator, ast_call_creator, \
|
|
24
24
|
ast_create_arg_value, ast_index_creator, ast_keyword_creator, ast_kwargs_creator, ast_name_creator, \
|
|
25
25
|
ast_num_creator, ast_str_creator, ast_subscript_creator
|
|
26
26
|
|
|
27
|
-
__all__ = ["AstFinder", "AstReplacer", "AstModifier", "StrChecker"]
|
|
27
|
+
__all__ = ["AstFinder", "AstReplacer", "AstModifier", "StrChecker", "CheckPropertyIsUsed", "GetPropertyOfObj"]
|
|
@@ -49,14 +49,13 @@ def ast_call_creator(func: ast.AST, args: list, keywords: list):
|
|
|
49
49
|
|
|
50
50
|
def ast_create_arg_value(value):
|
|
51
51
|
"""Create arg node by type."""
|
|
52
|
-
from mindspore.rewrite.node import Node
|
|
53
52
|
if isinstance(value, (int, float)):
|
|
54
53
|
ast_value = ast_num_creator(value)
|
|
55
54
|
elif isinstance(value, str):
|
|
56
55
|
ast_value = ast_str_creator(value)
|
|
57
56
|
elif value in (ms.float16, ms.float32, ms.float64):
|
|
58
57
|
ast_value = ast_attributer_creator(".".join(["mindspore", str(value).lower()]))
|
|
59
|
-
elif isinstance(value, Node):
|
|
58
|
+
elif isinstance(value, ms.rewrite.node.Node):
|
|
60
59
|
ast_value = ast_str_creator(value.get_targets()[0])
|
|
61
60
|
else:
|
|
62
61
|
raise TypeError("Unsupported arg type: ", type(value))
|
|
@@ -160,3 +160,68 @@ class FindConstValueInInit(ast.NodeVisitor):
|
|
|
160
160
|
self._hit = False
|
|
161
161
|
self.generic_visit(self._context)
|
|
162
162
|
return self._hit
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class CheckPropertyIsUsed(ast.NodeVisitor):
|
|
166
|
+
"""
|
|
167
|
+
Check whether a property is used.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
node (ast.AST): An instance of ast node.
|
|
171
|
+
"""
|
|
172
|
+
def __init__(self, node: ast.AST):
|
|
173
|
+
self._context = node
|
|
174
|
+
self._value = ""
|
|
175
|
+
self._attr = ""
|
|
176
|
+
self._hit = False
|
|
177
|
+
|
|
178
|
+
def visit_Attribute(self, node: ast.Attribute) -> Any: # pylint: disable=invalid-name
|
|
179
|
+
"""Visit a node of type ast.Attribute."""
|
|
180
|
+
if isinstance(node.value, ast.Name) and node.value.id == self._value and node.attr == self._attr:
|
|
181
|
+
self._hit = True
|
|
182
|
+
return super(CheckPropertyIsUsed, self).generic_visit(node)
|
|
183
|
+
|
|
184
|
+
def generic_visit(self, node: ast.AST) -> Any:
|
|
185
|
+
"""
|
|
186
|
+
An override method, iterating over all nodes and save target ast nodes.
|
|
187
|
+
"""
|
|
188
|
+
if self._hit:
|
|
189
|
+
return
|
|
190
|
+
super(CheckPropertyIsUsed, self).generic_visit(node)
|
|
191
|
+
|
|
192
|
+
def check(self, value, attr) -> bool:
|
|
193
|
+
"""
|
|
194
|
+
Check whether `value` and `attr` exists.
|
|
195
|
+
"""
|
|
196
|
+
self._value = value
|
|
197
|
+
self._attr = attr
|
|
198
|
+
self._hit = False
|
|
199
|
+
self.generic_visit(self._context)
|
|
200
|
+
return self._hit
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class GetPropertyOfObj(ast.NodeVisitor):
|
|
204
|
+
"""
|
|
205
|
+
Check whether a property is used.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
node (ast.AST): An instance of ast node.
|
|
209
|
+
"""
|
|
210
|
+
def __init__(self, node: ast.AST):
|
|
211
|
+
self._context = node
|
|
212
|
+
self._property = set()
|
|
213
|
+
|
|
214
|
+
def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name
|
|
215
|
+
"""Visit a node of type ast.Attribute."""
|
|
216
|
+
target = node.targets[0]
|
|
217
|
+
if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) and target.value.id == "self":
|
|
218
|
+
self._property.add(target.attr)
|
|
219
|
+
return super(GetPropertyOfObj, self).generic_visit(node)
|
|
220
|
+
|
|
221
|
+
def get(self):
|
|
222
|
+
"""
|
|
223
|
+
Check whether `value` and `attr` exists.
|
|
224
|
+
"""
|
|
225
|
+
self._property = set()
|
|
226
|
+
self.generic_visit(self._context)
|
|
227
|
+
return self._property
|
|
@@ -241,8 +241,10 @@ class AstModifier(ast.NodeTransformer):
|
|
|
241
241
|
An instance of ast.Assign which has been appended to 'init_func'.
|
|
242
242
|
"""
|
|
243
243
|
return AstModifier.insert_assign_to_function(init_func, targets=targets,
|
|
244
|
-
|
|
245
|
-
|
|
244
|
+
expr=ScopedValue(ValueType.NamingValue, "", "setattr"),
|
|
245
|
+
args=[ScopedValue(ValueType.NamingValue, "obj"),
|
|
246
|
+
ScopedValue.create_variable_value(field)])
|
|
247
|
+
|
|
246
248
|
|
|
247
249
|
@staticmethod
|
|
248
250
|
def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue],
|
|
@@ -459,7 +461,7 @@ class AstModifier(ast.NodeTransformer):
|
|
|
459
461
|
|
|
460
462
|
Args:
|
|
461
463
|
src_argument (ScopedValue): An instance of ScopedValue represents new argument.
|
|
462
|
-
dst_ast (ast.AST):
|
|
464
|
+
dst_ast (ast.AST): Ast node to be updated by ScopedValue.
|
|
463
465
|
|
|
464
466
|
Raises:
|
|
465
467
|
TypeError: Input src_argument is not a ScopedValue
|
|
@@ -490,6 +492,12 @@ class AstModifier(ast.NodeTransformer):
|
|
|
490
492
|
str(src_argument.type))
|
|
491
493
|
dst_ast.n = src_argument.value
|
|
492
494
|
return
|
|
495
|
+
if isinstance(dst_ast, ast.Str):
|
|
496
|
+
if src_argument.type not in [ValueType.StringValue]:
|
|
497
|
+
raise RuntimeError("src_argument should be a StringValue, but got:",
|
|
498
|
+
str(src_argument.type))
|
|
499
|
+
dst_ast.s = src_argument.value
|
|
500
|
+
return
|
|
493
501
|
if isinstance(dst_ast, ast.Name):
|
|
494
502
|
if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
|
|
495
503
|
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
|
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
from typing import Any, Tuple
|
|
18
18
|
import ast
|
|
19
19
|
from ast import FunctionDef
|
|
20
|
+
import astunparse
|
|
20
21
|
|
|
21
22
|
from mindspore import log as logger
|
|
22
23
|
from ..common import error_str
|
|
@@ -37,7 +38,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
37
38
|
ast.Call: ["args"],
|
|
38
39
|
ast.BinOp: ["left", "right"],
|
|
39
40
|
ast.BoolOp: ["values"],
|
|
40
|
-
ast.
|
|
41
|
+
ast.UnaryOp: ["operand"],
|
|
42
|
+
ast.Compare: ["left", "comparators"],
|
|
41
43
|
}
|
|
42
44
|
|
|
43
45
|
@staticmethod
|
|
@@ -54,7 +56,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
54
56
|
target_name = "function"
|
|
55
57
|
elif isinstance(node, ast.Return):
|
|
56
58
|
target_name = "return_value"
|
|
57
|
-
elif isinstance(node, (ast.BinOp, ast.
|
|
59
|
+
elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
|
|
58
60
|
target_name = type(node.op).__name__.lower() + "_var"
|
|
59
61
|
elif isinstance(node, ast.Tuple):
|
|
60
62
|
target_name = type(node).__name__.lower() + "_var"
|
|
@@ -180,6 +182,20 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
180
182
|
child = node.body[index]
|
|
181
183
|
if isinstance(child, ast.Assign):
|
|
182
184
|
stmt = child.value
|
|
185
|
+
elif isinstance(child, ast.If):
|
|
186
|
+
if isinstance(child.body[0], ast.Return) and not isinstance(child.test, ast.UnaryOp):
|
|
187
|
+
if isinstance(child.body[0].value, ast.Call):
|
|
188
|
+
if_body = child.body
|
|
189
|
+
if_func = if_body[0].value
|
|
190
|
+
expr = "x = " + astunparse.unparse(if_func)
|
|
191
|
+
if_body = ast.parse(expr)
|
|
192
|
+
if_body = if_body.body+ast.parse("return x").body
|
|
193
|
+
child.body = if_body
|
|
194
|
+
stmt = child
|
|
195
|
+
else:
|
|
196
|
+
stmt = child
|
|
197
|
+
else:
|
|
198
|
+
stmt = child
|
|
183
199
|
elif isinstance(child, ast.Expr):
|
|
184
200
|
stmt = child.value
|
|
185
201
|
else:
|
mindspore/rewrite/namespace.py
CHANGED
|
@@ -24,8 +24,6 @@ _ms_functional_ns = CellNamespace('mindspore.ops.functional')
|
|
|
24
24
|
|
|
25
25
|
def is_subtree(cls_name):
|
|
26
26
|
"""Determine whether 'cls_name' is a subtree."""
|
|
27
|
-
if cls_name == "SequentialCell":
|
|
28
|
-
return True
|
|
29
27
|
if cls_name == "QuantizeWrapperCell":
|
|
30
28
|
return False
|
|
31
29
|
if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:
|